From a4c83cb1e4b066cd60264b6572fd3e51d160d26a Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 21 Jul 2015 19:14:07 -0700
Subject: [PATCH 001/219] [SPARK-9154][SQL] Rename formatString to format_string.

Also make format_string the canonical form, rather than printf.

Author: Reynold Xin

Closes #7579 from rxin/format_strings and squashes the following commits:

53ee54f [Reynold Xin] Fixed unit tests.
52357e1 [Reynold Xin] Add format_string alias.
b40a42a [Reynold Xin] [SPARK-9154][SQL] Rename formatString to format_string.
---
 .../catalyst/analysis/FunctionRegistry.scala | 3 ++-
 .../expressions/stringOperations.scala | 13 +++++--------
 .../expressions/StringExpressionsSuite.scala | 14 +++++++-------
 .../scala/org/apache/spark/sql/functions.scala | 18 +++---------------
 .../spark/sql/StringFunctionsSuite.scala | 12 +-----------
 5 files changed, 18 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e3d8d2adf2135..9c349838c28a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -168,7 +168,8 @@ object FunctionRegistry {
     expression[StringLocate]("locate"),
     expression[StringLPad]("lpad"),
     expression[StringTrimLeft]("ltrim"),
-    expression[StringFormat]("printf"),
+    expression[FormatString]("format_string"),
+    expression[FormatString]("printf"),
     expression[StringRPad]("rpad"),
     expression[StringRepeat]("repeat"),
     expression[StringReverse]("reverse"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 1f18a6e9ff8a5..cf187ad5a0a9f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -526,29 +526,26 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
 /**
  * Returns the input formatted according to printf-style format strings
  */
-case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
+case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
 
-  require(children.nonEmpty, "printf() should take at least 1 argument")
+  require(children.nonEmpty, "format_string() should take at least 1 argument")
 
   override def foldable: Boolean = children.forall(_.foldable)
   override def nullable: Boolean = children(0).nullable
   override def dataType: DataType = StringType
-  private def format: Expression = children(0)
-  private def args: Seq[Expression] = children.tail
 
   override def inputTypes: Seq[AbstractDataType] =
     StringType :: List.fill(children.size - 1)(AnyDataType)
 
-
   override def eval(input: InternalRow): Any = {
-    val pattern = format.eval(input)
+    val pattern = children(0).eval(input)
     if (pattern == null) {
       null
     } else {
       val sb = new StringBuffer()
       val formatter = new java.util.Formatter(sb, Locale.US)
 
-      val arglist = args.map(_.eval(input).asInstanceOf[AnyRef])
+      val arglist = children.tail.map(_.eval(input).asInstanceOf[AnyRef])
       formatter.format(pattern.asInstanceOf[UTF8String].toString, arglist: _*)
 
       UTF8String.fromString(sb.toString)
@@ -591,7 +588,7 @@ case class
StringFormat(children: Expression*) extends Expression with ImplicitC """ } - override def prettyName: String = "printf" + override def prettyName: String = "format_string" } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3c2d88731beb4..3d294fda5d103 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,16 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa")), "aa", create_row(null)) + checkEvaluation(FormatString(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(FormatString(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation(FormatString(Literal.create(null, StringType), 12, "cc"), null) checkEvaluation( - StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + FormatString(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") checkEvaluation( - StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") + FormatString(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e5ff8ae7e3179..28159cbd5ab96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1742,26 +1742,14 @@ object functions { def rtrim(e: Column): Column = StringTrimRight(e.expr) /** - * Format strings in printf-style. + * Formats the arguments in printf-style and returns the result as a string column. * * @group string_funcs * @since 1.5.0 */ @scala.annotation.varargs - def formatString(format: Column, arguments: Column*): Column = { - StringFormat((format +: arguments).map(_.expr): _*) - } - - /** - * Format strings in printf-style. - * NOTE: `format` is the string value of the formatter, not column name. 
- *
- * @group string_funcs
- * @since 1.5.0
- */
-  @scala.annotation.varargs
-  def formatString(format: String, arguNames: String*): Column = {
-    StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*)
+  def format_string(format: String, arguments: Column*): Column = {
+    FormatString((lit(format) +: arguments).map(_.expr): _*)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3702e73b4e74f..0f9c986f649a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -126,22 +126,12 @@ class StringFunctionsSuite extends QueryTest {
     val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
 
     checkAnswer(
-      df.select(formatString("aa%d%s", "b", "c")),
+      df.select(format_string("aa%d%s", $"b", $"c")),
       Row("aa123cc"))
 
     checkAnswer(
       df.selectExpr("printf(a, b, c)"),
       Row("aa123cc"))
-
-    val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
-
-    checkAnswer(
-      df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
-      Row("aa123cc", "aa123cc"))
-
-    checkAnswer(
-      df2.selectExpr("printf(a, b, c)"),
-      Row("aa123cc"))
   }
 
   test("string instr function") {

From 63f4bcc73f5a09c1790cc3c333f08b18609de6a4 Mon Sep 17 00:00:00 2001
From: Yu ISHIKAWA
Date: Tue, 21 Jul 2015 22:50:27 -0700
Subject: [PATCH 002/219] [SPARK-9121] [SPARKR] Get rid of the warnings about
 `no visible global function definition` in SparkR

[[SPARK-9121] Get rid of the warnings about `no visible global function definition` in SparkR - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9121)

## The Result of `dev/lint-r`

[The result of lint-r for SPARK-9121 at revision 1ddd0f2f1688560f88470e312b72af04364e2d49, when I sent the PR](https://gist.github.com/yu-iskw/6f55953425901725edf6)

Author: Yu ISHIKAWA

Closes #7567 from yu-iskw/SPARK-9121 and squashes the following commits:

c8cfd63 [Yu ISHIKAWA] Fix the typo
b1f19ed [Yu ISHIKAWA] Add a validate statement for local SparkR
1a03987 [Yu ISHIKAWA] Load the `testthat` package in `dev/lint-r.R`, instead of using the full path of function.
3a5e0ab [Yu ISHIKAWA] [SPARK-9121][SparkR] Get rid of the warnings about `no visible global function definition` in SparkR
---
 dev/lint-r.R | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/dev/lint-r.R b/dev/lint-r.R
index dcb1a184291e1..48bd6246096ae 100644
--- a/dev/lint-r.R
+++ b/dev/lint-r.R
@@ -15,15 +15,21 @@
 # limitations under the License.
 #
 
+argv <- commandArgs(TRUE)
+SPARK_ROOT_DIR <- as.character(argv[1])
+
 # Installs lintr from GitHub.
 # NOTE: CRAN's version is too old to adapt to our rules.
 if ("lintr" %in% row.names(installed.packages()) == FALSE) {
   devtools::install_github("jimhester/lintr")
 }
-library(lintr)
-argv <- commandArgs(TRUE)
-SPARK_ROOT_DIR <- as.character(argv[1])
+
+library(lintr)
+library(methods)
+library(testthat)
+if (!
library(SparkR, lib.loc = file.path(SPARK_ROOT_DIR, "R", "lib"), logical.return = TRUE)) { + stop("You should install SparkR in a local directory with `R/install-dev.sh`.") +} path.to.package <- file.path(SPARK_ROOT_DIR, "R", "pkg") lint_package(path.to.package, cache = FALSE) From f4785f5b82c57bce41d3dc26ed9e3c9e794c7558 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 21 Jul 2015 23:00:13 -0700 Subject: [PATCH 003/219] [SPARK-9232] [SQL] Duplicate code in JSONRelation Author: Andrew Or Closes #7576 from andrewor14/clean-up-json-relation and squashes the following commits: ea80803 [Andrew Or] Clean up duplicate code --- .../apache/spark/sql/json/JSONRelation.scala | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 25802d054ac00..922794ac9aac5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.json import java.io.IOException -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.rdd.RDD import org.apache.spark.sql.AnalysisException @@ -87,20 +87,7 @@ private[sql] class DefaultSource case SaveMode.Append => sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}") case SaveMode.Overwrite => { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } + JSONRelation.delete(filesystemPath, fs) true } case SaveMode.ErrorIfExists => @@ -195,20 +182,7 @@ private[sql] class JSONRelation( if (overwrite) { if (fs.exists(filesystemPath)) { - var success: Boolean = false - try { - success = fs.delete(filesystemPath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table:\n${e.toString}") - } - if (!success) { - throw new IOException( - s"Unable to clear output directory ${filesystemPath.toString} prior" - + s" to writing to JSON table.") - } + JSONRelation.delete(filesystemPath, fs) } // Write the data. data.toJSON.saveAsTextFile(filesystemPath.toString) @@ -228,3 +202,21 @@ private[sql] class JSONRelation( case _ => false } } + +private object JSONRelation { + + /** Delete the specified directory to overwrite it with new JSON data. 
*/
+  def delete(dir: Path, fs: FileSystem): Unit = {
+    var success: Boolean = false
+    val failMessage = s"Unable to clear output directory $dir prior to writing to JSON table"
+    try {
+      success = fs.delete(dir, true /* recursive */)
+    } catch {
+      case e: IOException =>
+        throw new IOException(s"$failMessage\n${e.toString}")
+    }
+    if (!success) {
+      throw new IOException(failMessage)
+    }
+  }
+}

From c03299a18b4e076cabb4b7833a1e7632c5c0dabe Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Tue, 21 Jul 2015 23:26:11 -0700
Subject: [PATCH 004/219] [SPARK-4233] [SPARK-4367] [SPARK-3947] [SPARK-3056]
 [SQL] Aggregation Improvement

This is the first PR for the aggregation improvement, which is tracked by
https://issues.apache.org/jira/browse/SPARK-4366 (umbrella JIRA). This PR contains work for
its subtasks, SPARK-3056, SPARK-3947, SPARK-4233, and SPARK-4367.

This PR introduces a new code path for evaluating aggregate functions. This code path is
guarded by `spark.sql.useAggregate2`, and by default the value of this flag is true. This new
code path contains:

* A new aggregate function interface (`AggregateFunction2`) and 7 built-in aggregate functions
  based on this new interface (`AVG`, `COUNT`, `FIRST`, `LAST`, `MAX`, `MIN`, `SUM`)
* A UDAF interface (`UserDefinedAggregateFunction`) based on the new code path and two example
  UDAFs (`MyDoubleAvg` and `MyDoubleSum`).
* A sort-based aggregate operator (`Aggregate2Sort`) for the new aggregate function interface.
* A sort-based aggregate operator (`FinalAndCompleteAggregate2Sort`) for distinct aggregations
  (for distinct aggregations the query plan will use `Aggregate2Sort` and
  `FinalAndCompleteAggregate2Sort` together).

With this change, when `spark.sql.useAggregate2` is `true`, the flow of compiling an
aggregation query is:

1. Our analyzer looks up functions and returns aggregate functions built based on the old
   aggregate function interface.
2. When our planner is compiling the physical plan, it tries to convert all aggregate
   functions to the ones built based on the new interface. The planner will fall back to the
   old code path if any of the following conditions is true:
   * code-gen is disabled.
   * there is any function that cannot be converted (right now, Hive UDAFs).
   * the schema of grouping expressions contains any complex data type.
   * there are multiple distinct columns. Right now, the new code path handles a single
     distinct column in the query (you can have multiple aggregate functions using that
     distinct column). For a query having an aggregate function with DISTINCT and regular
     aggregate functions, the generated plan will do partial aggregation for the regular
     aggregate functions.

Thanks chenghao-intel for his initial work on it.

Author: Yin Huai
Author: Michael Armbrust

Closes #7458 from yhuai/UDAF and squashes the following commits:

7865f5e [Yin Huai] Put the catalyst expression in the comment of the generated code for it.
b04d6c8 [Yin Huai] Remove unnecessary change.
f1d5901 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
35b0520 [Yin Huai] Use semanticEquals to replace grouping expressions in the output of the aggregate operator.
3b43b24 [Yin Huai] bug fix.
00eb298 [Yin Huai] Make it compile.
a3ca551 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
e0afca3 [Yin Huai] Gracefully fallback to old aggregation code path.
8a8ac4a [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF
88c7d4d [Yin Huai] Enable spark.sql.useAggregate2 by default for testing purpose.
dc96fd1 [Yin Huai] Many updates: 85c9c4b [Yin Huai] newline. 43de3de [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF c3614d7 [Yin Huai] Handle single distinct column. 68b8ee9 [Yin Huai] Support single distinct column set. WIP 3013579 [Yin Huai] Format. d678aee [Yin Huai] Remove AggregateExpressionSuite.scala since our built-in aggregate functions will be based on AlgebraicAggregate and we need to have another way to test it. e243ca6 [Yin Huai] Add aggregation iterators. a101960 [Yin Huai] Change MyJavaUDAF to MyDoubleSum. 594cdf5 [Yin Huai] Change existing AggregateExpression to AggregateExpression1 and add an AggregateExpression as the common interface for both AggregateExpression1 and AggregateExpression2. 380880f [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF 0a827b3 [Yin Huai] Add comments and doc. Move some classes to the right places. a19fea6 [Yin Huai] Add UDAF interface. 262d4c4 [Yin Huai] Make it compile. b2e358e [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF 6edb5ac [Yin Huai] Format update. 70b169c [Yin Huai] Remove groupOrdering. 4721936 [Yin Huai] Add CheckAggregateFunction to extendedCheckRules. d821a34 [Yin Huai] Cleanup. 32aea9c [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF 5b46d41 [Yin Huai] Bug fix. aff9534 [Yin Huai] Make Aggregate2Sort work with both algebraic AggregateFunctions and non-algebraic AggregateFunctions. 2857b55 [Yin Huai] Merge remote-tracking branch 'upstream/master' into UDAF 4435f20 [Yin Huai] Add ConvertAggregateFunction to HiveContext's analyzer. 1b490ed [Michael Armbrust] make hive test 8cfa6a9 [Michael Armbrust] add test 1b0bb3f [Yin Huai] Do not bind references in AlgebraicAggregate and use code gen for all places. 072209f [Yin Huai] Bug fix: Handle expressions in grouping columns that are not attribute references. f7d9e54 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into UDAF 39ee975 [Yin Huai] Code cleanup: Remove unnecesary AttributeReferences. b7720ba [Yin Huai] Add an analysis rule to convert aggregate function to the new version. 5c00f3f [Michael Armbrust] First draft of codegen 6bbc6ba [Michael Armbrust] now with correct answers\! 
f7996d0 [Michael Armbrust] Add AlgebraicAggregate dded1c5 [Yin Huai] wip --- .../apache/spark/sql/catalyst/SqlParser.scala | 3 +- .../sql/catalyst/analysis/Analyzer.scala | 24 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 + .../sql/catalyst/analysis/unresolved.scala | 5 +- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 3 +- .../expressions/aggregate/functions.scala | 292 +++++++ .../expressions/aggregate/interfaces.scala | 206 +++++ .../sql/catalyst/expressions/aggregates.scala | 100 +-- .../codegen/GenerateMutableProjection.scala | 21 +- .../sql/catalyst/planning/patterns.scala | 4 +- .../plans/logical/basicOperators.scala | 1 + .../scala/org/apache/spark/sql/SQLConf.scala | 5 + .../org/apache/spark/sql/SQLContext.scala | 4 + .../apache/spark/sql/UDAFRegistration.scala | 35 + .../spark/sql/execution/Aggregate.scala | 12 +- .../apache/spark/sql/execution/Exchange.scala | 11 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 100 ++- .../aggregate/aggregateOperators.scala | 173 ++++ .../aggregate/sortBasedIterators.scala | 749 ++++++++++++++++++ .../spark/sql/execution/aggregate/utils.scala | 364 +++++++++ .../sql/expressions/aggregate/udaf.scala | 280 +++++++ .../org/apache/spark/sql/functions.scala | 4 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 26 +- .../HiveWindowFunctionQuerySuite.scala | 1 + .../SortMergeCompatibilitySuite.scala | 7 + .../apache/spark/sql/hive/HiveContext.scala | 1 + .../org/apache/spark/sql/hive/HiveQl.scala | 7 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 8 +- .../spark/sql/hive/aggregate/MyDoubleAvg.java | 107 +++ .../spark/sql/hive/aggregate/MyDoubleSum.java | 100 +++ ...f_unhex-0-50131c0ba7b7a6b65c789a5a8497bada | 1 + ...f_unhex-1-11eb3cc5216d5446f4165007203acc47 | 1 + ...f_unhex-2-a660886085b8651852b9b77934848ae4 | 14 + ...f_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e | 1 + ...f_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 | 1 + .../execution/AggregationQuerySuite.scala | 507 ++++++++++++ 39 files changed, 3087 insertions(+), 100 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala create mode 100644 sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java create mode 100644 sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 create mode 100644 sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e create mode 100644 
sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index d4ef04c2294a2..c04bd6cd85187 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -266,11 +266,12 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
         }
       }
     | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
-      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs) }
+      { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
     | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^
       { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match {
         case "sum" => SumDistinct(exprs.head)
         case "count" => CountDistinct(exprs)
+        case name => UnresolvedFunction(name, exprs, isDistinct = true)
         case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT")
       }
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e58f3f64947f3..8cadbc57e87e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -277,7 +278,7 @@ class Analyzer(
       Project(
         projectList.flatMap {
           case s: Star => s.expand(child.output, resolver)
-          case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
+          case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) =>
             val expandedArgs = args.flatMap {
               case s: Star => s.expand(child.output, resolver)
               case o => o :: Nil
@@ -517,9 +518,26 @@ class Analyzer(
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
       case q: LogicalPlan =>
         q transformExpressions {
-          case u @ UnresolvedFunction(name, children) =>
+          case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
-              registry.lookupFunction(name, children)
+              registry.lookupFunction(name, children) match {
+                // We get an aggregate function built based on the AggregateFunction2 interface.
+                // So, we wrap it in AggregateExpression2.
+                case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+                // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+                // and COUNT(DISTINCT ...).
+                case sumDistinct: SumDistinct => sumDistinct
+                case countDistinct: CountDistinct => countDistinct
+                // DISTINCT is not meaningful with Max and Min.
+                case max: Max if isDistinct => max
+                case min: Min if isDistinct => min
+                // For other aggregate functions, the DISTINCT keyword is not supported for now.
+                // Once we have converted to the new code path, we will allow using the DISTINCT keyword.
+ case other if isDistinct => + failAnalysis(s"$name does not support DISTINCT keyword.") + // If it does not have DISTINCT keyword, we will return it as is. + case other => other + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c7f9713344c50..c203fcecf20fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 0daee1990a6e0..03da45b09f928 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -73,7 +73,10 @@ object UnresolvedAttribute { def quoted(name: String): UnresolvedAttribute = new UnresolvedAttribute(Seq(name)) } -case class UnresolvedFunction(name: String, children: Seq[Expression]) +case class UnresolvedFunction( + name: String, + children: Seq[Expression], + isDistinct: Boolean) extends Expression with Unevaluable { override def dataType: DataType = throw new UnresolvedException(this, "dataType") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b09aea03318da..b10a3c877434b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal]" + override def toString: String = s"input[$ordinal, $dataType]" override def eval(input: InternalRow): Any = input(ordinal) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aada25276adb7..29ae47e842ddb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -96,7 +96,8 @@ abstract class Expression extends TreeNode[Expression] { val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) ve.code = genCode(ctx, ve) - ve + // Add `this` in the comment. 
+ ve.copy(s"/* $this */\n" + ve.code) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala new file mode 100644 index 0000000000000..b924af4cc84d8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +case class Average(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + // TODO: Once we remove the old code path, we can use our analyzer to cast NullType + // to the default data type of the NumericType. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => DoubleType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => DoubleType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentSum :: currentCount :: Nil + + override val initialValues = Seq( + /* currentSum = */ Cast(Literal(0), sumDataType), + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Add( + currentSum, + Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentSum = */ currentSum.left + currentSum.right, + /* currentCount = */ currentCount.left + currentCount.right + ) + + // If all input are nulls, currentCount will be 0 and we will get null after the division. + override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) +} + +case class Count(child: Expression) extends AlgebraicAggregate { + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = false + + // Return data type. + override def dataType: DataType = LongType + + // Expected input data type. 
+ override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val currentCount = AttributeReference("currentCount", LongType)() + + override val bufferAttributes = currentCount :: Nil + + override val initialValues = Seq( + /* currentCount = */ Literal(0L) + ) + + override val updateExpressions = Seq( + /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + ) + + override val mergeExpressions = Seq( + /* currentCount = */ currentCount.left + currentCount.right + ) + + override val evaluateExpression = Cast(currentCount, LongType) +} + +case class First(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // First is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val first = AttributeReference("first", child.dataType)() + + override val bufferAttributes = first :: Nil + + override val initialValues = Seq( + /* first = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* first = */ If(IsNull(first), child, first) + ) + + override val mergeExpressions = Seq( + /* first = */ If(IsNull(first.left), first.right, first.left) + ) + + override val evaluateExpression = first +} + +case class Last(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Last is not a deterministic function. + override def deterministic: Boolean = false + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val last = AttributeReference("last", child.dataType)() + + override val bufferAttributes = last :: Nil + + override val initialValues = Seq( + /* last = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* last = */ If(IsNull(child), last, child) + ) + + override val mergeExpressions = Seq( + /* last = */ If(IsNull(last.right), last.left, last.right) + ) + + override val evaluateExpression = last +} + +case class Max(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val max = AttributeReference("max", child.dataType)() + + override val bufferAttributes = max :: Nil + + override val initialValues = Seq( + /* max = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) + ) + + override val mergeExpressions = { + val greatest = Greatest(Seq(max.left, max.right)) + Seq( + /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) + ) + } + + override val evaluateExpression = max +} + +case class Min(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. 
+ override def dataType: DataType = child.dataType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + private val min = AttributeReference("min", child.dataType)() + + override val bufferAttributes = min :: Nil + + override val initialValues = Seq( + /* min = */ Literal.create(null, child.dataType) + ) + + override val updateExpressions = Seq( + /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) + ) + + override val mergeExpressions = { + val least = Least(Seq(min.left, min.right)) + Seq( + /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) + ) + } + + override val evaluateExpression = min +} + +case class Sum(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) + + private val resultType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => DecimalType.Unlimited + case _ => child.dataType + } + + private val sumDataType = child.dataType match { + case _ @ DecimalType() => DecimalType.Unlimited + case _ => child.dataType + } + + private val currentSum = AttributeReference("currentSum", sumDataType)() + + private val zero = Cast(Literal(0), sumDataType) + + override val bufferAttributes = currentSum :: Nil + + override val initialValues = Seq( + /* currentSum = */ Literal.create(null, sumDataType) + ) + + override val updateExpressions = Seq( + /* currentSum = */ + Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + ) + + override val mergeExpressions = { + val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + Seq( + /* currentSum = */ + Coalesce(Seq(add, currentSum.left)) + ) + } + + override val evaluateExpression = Cast(currentSum, resultType) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala new file mode 100644 index 0000000000000..577ede73cb01f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** The mode of an [[AggregateFunction1]]. */ +private[sql] sealed trait AggregateMode + +/** + * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object Partial extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the aggregation buffer is returned. + */ +private[sql] case object PartialMerge extends AggregateMode + +/** + * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function and the generate final result. + * This function updates the given aggregation buffer by merging multiple aggregation buffers. + * When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Final extends AggregateMode + +/** + * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly + * from original input rows without any partial aggregation. + * This function updates the given aggregation buffer with the original input of this + * function. When it has processed all input rows, the final result of this function is returned. + */ +private[sql] case object Complete extends AggregateMode + +/** + * A place holder expressions used in code-gen, it does not change the corresponding value + * in the row. + */ +private[sql] case object NoOp extends Expression with Unevaluable { + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = { + throw new TreeNodeException( + this, s"No function to evaluate expression. type: ${this.nodeName}") + } + override def dataType: DataType = NullType + override def children: Seq[Expression] = Nil +} + +/** + * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. 
+ * @param aggregateFunction + * @param mode + * @param isDistinct + */ +private[sql] case class AggregateExpression2( + aggregateFunction: AggregateFunction2, + mode: AggregateMode, + isDistinct: Boolean) extends AggregateExpression { + + override def children: Seq[Expression] = aggregateFunction :: Nil + override def dataType: DataType = aggregateFunction.dataType + override def foldable: Boolean = false + override def nullable: Boolean = aggregateFunction.nullable + + override def references: AttributeSet = { + val childReferemces = mode match { + case Partial | Complete => aggregateFunction.references.toSeq + case PartialMerge | Final => aggregateFunction.bufferAttributes + } + + AttributeSet(childReferemces) + } + + override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" +} + +abstract class AggregateFunction2 + extends Expression with ImplicitCastInputTypes { + + self: Product => + + /** An aggregate function is not foldable. */ + override def foldable: Boolean = false + + /** + * The offset of this function's buffer in the underlying buffer shared with other functions. + */ + var bufferOffset: Int = 0 + + /** The schema of the aggregation buffer. */ + def bufferSchema: StructType + + /** Attributes of fields in bufferSchema. */ + def bufferAttributes: Seq[AttributeReference] + + /** Clones bufferAttributes. */ + def cloneBufferAttributes: Seq[Attribute] + + /** + * Initializes its aggregation buffer located in `buffer`. + * It will use bufferOffset to find the starting point of + * its buffer in the given `buffer` shared with other functions. + */ + def initialize(buffer: MutableRow): Unit + + /** + * Updates its aggregation buffer located in `buffer` based on the given `input`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer` + * shared with other functions. + */ + def update(buffer: MutableRow, input: InternalRow): Unit + + /** + * Updates its aggregation buffer located in `buffer1` by combining intermediate results + * in the current buffer and intermediate results from another buffer `buffer2`. + * It will use bufferOffset to find the starting point of its buffer in the given `buffer1` + * and `buffer2`. + */ + def merge(buffer1: MutableRow, buffer2: InternalRow): Unit + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + +/** + * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. + */ +abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { + self: Product => + + val initialValues: Seq[Expression] + val updateExpressions: Seq[Expression] + val mergeExpressions: Seq[Expression] + val evaluateExpression: Expression + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + /** + * A helper class for representing an attribute used in merging two + * aggregation buffers. When merging two buffers, `bufferLeft` and `bufferRight`, + * we merge buffer values and then update bufferLeft. A [[RichAttribute]] + * of an [[AttributeReference]] `a` has two functions `left` and `right`, + * which represent `a` in `bufferLeft` and `bufferRight`, respectively. + * @param a + */ + implicit class RichAttribute(a: AttributeReference) { + /** Represents this attribute at the mutable buffer side. 
*/ + def left: AttributeReference = a + + /** Represents this attribute at the input buffer side (the data value is read-only). */ + def right: AttributeReference = cloneBufferAttributes(bufferAttributes.indexOf(a)) + } + + /** An AlgebraicAggregate's bufferSchema is derived from bufferAttributes. */ + override def bufferSchema: StructType = StructType.fromAttributes(bufferAttributes) + + override def initialize(buffer: MutableRow): Unit = { + var i = 0 + while (i < bufferAttributes.size) { + buffer(i + bufferOffset) = initialValues(i).eval() + i += 1 + } + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's update should not be called directly") + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's merge should not be called directly") + } + + override def eval(buffer: InternalRow): Any = { + throw new UnsupportedOperationException( + "AlgebraicAggregate's eval should not be called directly") + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index d705a1286065c..e07c920a41d0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -27,7 +27,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet -trait AggregateExpression extends Expression with Unevaluable { +trait AggregateExpression extends Expression with Unevaluable + +trait AggregateExpression1 extends AggregateExpression { /** * Aggregate expressions should not be foldable. @@ -38,7 +40,7 @@ trait AggregateExpression extends Expression with Unevaluable { * Creates a new instance that can be used to compute this aggregate expression for a group * of input rows/ */ - def newInstance(): AggregateFunction + def newInstance(): AggregateFunction1 } /** @@ -54,10 +56,10 @@ case class SplitEvaluation( partialEvaluations: Seq[NamedExpression]) /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. + * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ -trait PartialAggregate extends AggregateExpression { +trait PartialAggregate1 extends AggregateExpression1 { /** * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. @@ -67,13 +69,13 @@ trait PartialAggregate extends AggregateExpression { /** * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression]] with an algorithm that will be used to compute one specific result. + * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. 
*/ -abstract class AggregateFunction - extends LeafExpression with AggregateExpression with Serializable { +abstract class AggregateFunction1 + extends LeafExpression with AggregateExpression1 with Serializable { /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression + val base: AggregateExpression1 override def nullable: Boolean = base.nullable override def dataType: DataType = base.dataType @@ -81,12 +83,12 @@ abstract class AggregateFunction def update(input: InternalRow): Unit // Do we really need this? - override def newInstance(): AggregateFunction = { + override def newInstance(): AggregateFunction1 = { makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } } -case class Min(child: Expression) extends UnaryExpression with PartialAggregate { +case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -102,7 +104,7 @@ case class Min(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function min") } -case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -119,7 +121,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMin.value } -case class Max(child: Expression) extends UnaryExpression with PartialAggregate { +case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -135,7 +137,7 @@ case class Max(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForOrderingExpr(child.dataType, "function max") } -case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) @@ -152,7 +154,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr override def eval(input: InternalRow): Any = currentMax.value } -case class Count(child: Expression) extends UnaryExpression with PartialAggregate { +case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -165,7 +167,7 @@ case class Count(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): CountFunction = new CountFunction(child, this) } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
var count: Long = _ @@ -180,7 +182,7 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = count } -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -200,8 +202,8 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate case class CountDistinctFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -220,7 +222,7 @@ case class CountDistinctFunction( override def eval(input: InternalRow): Any = seen.size.toLong } -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { +case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = expressions @@ -233,8 +235,8 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress case class CollectHashSetFunction( @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -255,7 +257,7 @@ case class CollectHashSetFunction( } } -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = inputSet :: Nil @@ -269,8 +271,8 @@ case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression case class CombineSetsAndCountFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -305,7 +307,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: DataType = HyperLogLogUDT @@ -317,9 +319,9 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) case class ApproxCountDistinctPartitionFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. 
private val hyperLogLog = new HyperLogLog(relativeSD) @@ -335,7 +337,7 @@ case class ApproxCountDistinctPartitionFunction( } case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression { + extends UnaryExpression with AggregateExpression1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -347,9 +349,9 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) case class ApproxCountDistinctMergeFunction( expr: Expression, - base: AggregateExpression, + base: AggregateExpression1, relativeSD: Double) - extends AggregateFunction { + extends AggregateFunction1 { def this() = this(null, null, 0) // Required for serialization. private val hyperLogLog = new HyperLogLog(relativeSD) @@ -363,7 +365,7 @@ case class ApproxCountDistinctMergeFunction( } case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate { + extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = false override def dataType: LongType.type = LongType @@ -381,7 +383,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) } -case class Average(child: Expression) extends UnaryExpression with PartialAggregate { +case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { override def prettyName: String = "avg" @@ -427,8 +429,8 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg TypeUtils.checkForNumericExpr(child.dataType, "function average") } -case class AverageFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class AverageFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -474,7 +476,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate { +case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true @@ -509,7 +511,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate TypeUtils.checkForNumericExpr(child.dataType, "function sum") } -case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
private val calcType = @@ -554,7 +556,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr * <-- null <-- no data * null <-- null <-- no data */ -case class CombineSum(child: Expression) extends AggregateExpression { +case class CombineSum(child: Expression) extends AggregateExpression1 { def this() = this(null) override def children: Seq[Expression] = child :: Nil @@ -564,8 +566,8 @@ case class CombineSum(child: Expression) extends AggregateExpression { override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) } -case class CombineSumFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class CombineSumFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -601,7 +603,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate { +case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) override def nullable: Boolean = true @@ -627,8 +629,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") } -case class SumDistinctFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -653,7 +655,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { def this() = this(null, null) override def children: Seq[Expression] = inputSet :: Nil @@ -667,8 +669,8 @@ case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends Agg case class CombineSetsAndSumFunction( @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { + @transient base: AggregateExpression1) + extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. @@ -695,7 +697,7 @@ case class CombineSetsAndSumFunction( } } -case class First(child: Expression) extends UnaryExpression with PartialAggregate { +case class First(child: Expression) extends UnaryExpression with PartialAggregate1 { override def nullable: Boolean = true override def dataType: DataType = child.dataType override def toString: String = s"FIRST($child)" @@ -709,7 +711,7 @@ case class First(child: Expression) extends UnaryExpression with PartialAggregat override def newInstance(): FirstFunction = new FirstFunction(child, this) } -case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class FirstFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. 
var result: Any = null @@ -723,7 +725,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } -case class Last(child: Expression) extends UnaryExpression with PartialAggregate { +case class Last(child: Expression) extends UnaryExpression with PartialAggregate1 { override def references: AttributeSet = child.references override def nullable: Boolean = true override def dataType: DataType = child.dataType @@ -738,7 +740,7 @@ case class Last(child: Expression) extends UnaryExpression with PartialAggregate override def newInstance(): LastFunction = new LastFunction(child, this) } -case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { +case class LastFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { def this() = this(null, null) // Required for serialization. var result: Any = null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 03b4b3c216f49..d838268f46956 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import scala.collection.mutable.ArrayBuffer @@ -38,15 +39,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCode = expressions.zipWithIndex.map { case (e, i) => - val evaluationCode = e.gen(ctx) - evaluationCode.code + - s""" - if(${evaluationCode.isNull}) - mutableRow.setNullAt($i); - else - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; - """ + val projectionCode = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => + val evaluationCode = e.gen(ctx) + evaluationCode.code + + s""" + if(${evaluationCode.isNull}) + mutableRow.setNullAt($i); + else + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + """ } // collect projections into blocks as function has 64kb codesize limit in JVM val projectionBlocks = new ArrayBuffer[String]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 179a348d5baac..b8e3b0d53a505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -129,10 +129,10 @@ object PartialAggregation { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => // Collect all aggregate expressions. val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) + aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) // Collect all aggregate expressions that can be computed partially. 
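+      // Such expressions can be split into a partial step and a final step;
+      // roughly, Average becomes a partial sum and count followed by a final
+      // divide once the partial results are merged (a sketch of the
+      // PartialAggregate1 contract, not shown in full here).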
val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) + aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) // Only do partial aggregation if supported by all aggregate expressions. if (allAggregates.size == partialAggregates.size) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 986c315b3173a..6aefa9f67556a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 78c780bdc5797..1474b170ba896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -402,6 +402,9 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) + val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", + defaultValue = Some(true), doc = "") + val USE_SQL_SERIALIZER2 = booleanConf( "spark.sql.useSerializer2", defaultValue = Some(true), isPublic = false) @@ -473,6 +476,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED) + private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) + private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2) private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8b4528b5d52fe..49bfe74b680af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -285,6 +285,9 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient val udf: UDFRegistration = new UDFRegistration(this) + @transient + val udaf: UDAFRegistration = new UDAFRegistration(this) + /** * Returns true if the table is currently cached in-memory. * @group cachemgmt @@ -863,6 +866,7 @@ class SQLContext(@transient val sparkContext: SparkContext) DDLStrategy :: TakeOrderedAndProject :: HashAggregation :: + Aggregation :: LeftSemiJoin :: HashJoin :: InMemoryScans :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala new file mode 100644 index 0000000000000..5b872f5e3eecd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Expression} +import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} + +class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { + + private val functionRegistry = sqlContext.functionRegistry + + def register( + name: String, + func: UserDefinedAggregateFunction): UserDefinedAggregateFunction = { + def builder(children: Seq[Expression]) = ScalaUDAF(children, func) + functionRegistry.registerFunction(name, builder) + func + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 3cd60a2aa55ed..c2c945321db95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -68,14 +68,14 @@ case class Aggregate( * output. */ case class ComputedAggregate( - unbound: AggregateExpression, - aggregate: AggregateExpression, + unbound: AggregateExpression1, + aggregate: AggregateExpression1, resultAttribute: AttributeReference) /** A list of aggregates that need to be computed for each group. */ private[this] val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { - case a: AggregateExpression => + case a: AggregateExpression1 => ComputedAggregate( a, BindReferences.bindReference(a, child.output), @@ -87,8 +87,8 @@ case class Aggregate( private[this] val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. 
*/ - private[this] def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) + private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { + val buffer = new Array[AggregateFunction1](computedAggregates.length) var i = 0 while (i < computedAggregates.length) { buffer(i) = computedAggregates(i).aggregate.newInstance() @@ -146,7 +146,7 @@ case class Aggregate( } } else { child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction]] + val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 2750053594f99..d31e265a293e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -247,8 +247,15 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } def addSortIfNecessary(child: SparkPlan): SparkPlan = { - if (rowOrdering.nonEmpty && child.outputOrdering != rowOrdering) { - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + + if (rowOrdering.nonEmpty) { + // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort. + val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min + if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) { + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + } else { + child + } } else { child } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ecde9c57139a6..0e63f2fe29cb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -69,7 +69,7 @@ case class GeneratedAggregate( protected override def doExecute(): RDD[InternalRow] = { val aggregatesToCompute = aggregateExpressions.flatMap { a => - a.collect { case agg: AggregateExpression => agg} + a.collect { case agg: AggregateExpression1 => agg} } // If you add any new function support, please add tests in org.apache.spark.sql.SQLQuerySuite diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8cef7f200d2dc..f54aa2027f6a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -148,7 +149,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if 
canBeCodeGened( allAggregates(partialComputation) ++ allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled => + codegenEnabled && + !canBeConvertedToNewAggregation(plan) => execution.GeneratedAggregate( partial = false, namedGroupingAttributes, @@ -167,7 +169,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { rewrittenAggregateExpressions, groupingExpressions, partialComputation, - child) => + child) if !canBeConvertedToNewAggregation(plan) => execution.Aggregate( partial = false, namedGroupingAttributes, @@ -181,7 +183,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { + aggregate.Utils.tryConvert( + plan, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled).isDefined + } + + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && @@ -189,10 +198,74 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => true } - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = - exprs.flatMap(_.collect { case a: AggregateExpression => a }) + def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = + exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) } + /** + * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. + */ + object Aggregation extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case p: logical.Aggregate => + val converted = + aggregate.Utils.tryConvert( + p, + sqlContext.conf.useSqlAggregate2, + sqlContext.conf.codegenEnabled) + converted match { + case None => Nil // Cannot convert to new aggregation code path. + case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => + // Extracts all distinct aggregate expressions from the resultExpressions. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + (aggregateFunction, agg.isDistinct) -> + Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets (aggregate.NewAggregation will not match). 
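+              // For example, COUNT(DISTINCT a) combined with COUNT(DISTINCT b)
+              // yields two distinct column sets, [a] and [b], which this code
+              // path cannot plan.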
+                sys.error(
+                  "Multiple distinct column sets are not supported by the new aggregation " +
+                    "code path.")
+            }
+
+            val aggregateOperator =
+              if (functionsWithDistinct.isEmpty) {
+                aggregate.Utils.planAggregateWithoutDistinct(
+                  groupingExpressions,
+                  aggregateExpressions,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              } else {
+                aggregate.Utils.planAggregateWithOneDistinct(
+                  groupingExpressions,
+                  functionsWithDistinct,
+                  functionsWithoutDistinct,
+                  aggregateFunctionMap,
+                  resultExpressions,
+                  planLater(child))
+              }
+
+            aggregateOperator
+        }
+
+      case _ => Nil
+    }
+  }
+
+
   object BroadcastNestedLoopJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
@@ -336,8 +409,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Filter(condition, planLater(child)) :: Nil
       case e @ logical.Expand(_, _, _, child) =>
         execution.Expand(e.projections, e.output, planLater(child)) :: Nil
-      case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+      case a @ logical.Aggregate(group, agg, child) => {
+        val useNewAggregation =
+          aggregate.Utils.tryConvert(
+            a,
+            sqlContext.conf.useSqlAggregate2,
+            sqlContext.conf.codegenEnabled).isDefined
+        if (useNewAggregation) {
+          // If this logical.Aggregate can be planned to use the new aggregation code path
+          // (i.e. it can be planned by the Aggregation strategy), we do not use the old
+          // aggregation code path.
+          Nil
+        } else {
+          execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil
+        }
+      }
       case logical.Window(projectList, windowExpressions, spec, child) =>
         execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil
       case logical.Sample(lb, ub, withReplacement, seed, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
new file mode 100644
index 0000000000000..0c9082897f390
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} + +case class Aggregate2Sort( + requiredChildDistributionExpressions: Option[Seq[Expression]], + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + + override def canProcessUnsafeRows: Boolean = true + + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + aggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + requiredChildDistributionExpressions match { + case Some(exprs) if exprs.length == 0 => AllTuples :: Nil + case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case None => UnspecifiedDistribution :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + // TODO: We should not sort the input rows if they are just in reversed order. + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + } + + override def outputOrdering: Seq[SortOrder] = { + // It is possible that the child.outputOrdering starts with the required + // ordering expressions (e.g. we require [a] as the sort expression and the + // child's outputOrdering is [a, b]). We can only guarantee the output rows + // are sorted by values of groupingExpressions. 
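+    // Exposing only the grouping ordering also lets EnsureRequirements reuse a
+    // child ordering whose prefix matches, thanks to the prefix check added to
+    // addSortIfNecessary in Exchange.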
+ groupingExpressions.map(SortOrder(_, Ascending)) + } + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + if (aggregateExpressions.length == 0) { + new GroupingIterator( + groupingExpressions, + resultExpressions, + newMutableProjection, + child.output, + iter) + } else { + val aggregationIterator: SortAggregationIterator = { + aggregateExpressions.map(_.mode).distinct.toList match { + case Partial :: Nil => + new PartialSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case PartialMerge :: Nil => + new PartialMergeSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + child.output, + iter) + case Final :: Nil => + new FinalSortAggregationIterator( + groupingExpressions, + aggregateExpressions, + aggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + case other => + sys.error( + s"Could not evaluate ${aggregateExpressions} because we do not support evaluate " + + s"modes $other in this operator.") + } + } + + aggregationIterator + } + } + } +} + +case class FinalAndCompleteAggregate2Sort( + previousGroupingExpressions: Seq[NamedExpression], + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan) + extends UnaryNode { + override def references: AttributeSet = { + val referencesInResults = + AttributeSet(resultExpressions.flatMap(_.references)) -- + AttributeSet(finalAggregateExpressions) -- + AttributeSet(completeAggregateExpressions) + + AttributeSet( + groupingExpressions.flatMap(_.references) ++ + finalAggregateExpressions.flatMap(_.references) ++ + completeAggregateExpressions.flatMap(_.references) ++ + referencesInResults) + } + + override def requiredChildDistribution: List[Distribution] = { + if (groupingExpressions.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + groupingExpressions.map(SortOrder(_, Ascending)) :: Nil + + override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + + new FinalAndCompleteSortAggregationIterator( + previousGroupingExpressions.length, + groupingExpressions, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + resultExpressions, + newMutableProjection, + child.output, + iter) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala new file mode 100644 index 0000000000000..ce1cbdc9cb090 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -0,0 +1,749 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.NullType + +import scala.collection.mutable.ArrayBuffer + +/** + * An iterator used to evaluate aggregate functions. It assumes that input rows + * are already grouped by values of `groupingExpressions`. + */ +private[sql] abstract class SortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends Iterator[InternalRow] { + + /////////////////////////////////////////////////////////////////////////// + // Static fields for this iterator + /////////////////////////////////////////////////////////////////////////// + + protected val aggregateFunctions: Array[AggregateFunction2] = { + var bufferOffset = initialBufferOffset + val functions = new Array[AggregateFunction2](aggregateExpressions.length) + var i = 0 + while (i < aggregateExpressions.length) { + val func = aggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = aggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, inputAttributes) + case _ => func + } + // Set bufferOffset for this function. It is important that setting bufferOffset + // happens after all potential bindReference operations because bindReference + // will create a new instance of the function. + funcWithBoundReferences.bufferOffset = bufferOffset + bufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // All non-algebraic aggregate functions. + protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + aggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // Positions of those non-algebraic aggregate functions in aggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. 
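+  // These positions are used when generating output: a projection computes the
+  // algebraic results, and the remaining slots are filled by, e.g.,
+  //   aggregateResult.update(nonAlgebraicAggregateFunctionPositions(i),
+  //     nonAlgebraicAggregateFunctions(i).eval(buffer))
+  // as done in the final iterators below.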
+  protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
+    val positions = new ArrayBuffer[Int]()
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      aggregateFunctions(i) match {
+        case agg: AlgebraicAggregate =>
+        case _ => positions += i
+      }
+      i += 1
+    }
+    positions.toArray
+  }
+
+  // This is used to project the grouping expressions.
+  protected val groupGenerator =
+    newMutableProjection(groupingExpressions, inputAttributes)()
+
+  // The underlying buffer shared by all aggregate functions.
+  protected val buffer: MutableRow = {
+    // The number of elements of the underlying buffer of this operator.
+    // All aggregate functions share this underlying buffer and they find their
+    // buffer values through bufferOffset.
+    var size = initialBufferOffset
+    var i = 0
+    while (i < aggregateFunctions.length) {
+      size += aggregateFunctions(i).bufferSchema.length
+      i += 1
+    }
+    new GenericMutableRow(size)
+  }
+
+  protected val joinedRow = new JoinedRow4
+
+  protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp)
+
+  // This projection is used to initialize buffer values for all AlgebraicAggregates.
+  protected val algebraicInitialProjection = {
+    val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.initialValues
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(initExpressions, Nil)().target(buffer)
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Mutable states
+  ///////////////////////////////////////////////////////////////////////////
+
+  // The grouping key of the current group.
+  protected var currentGroupingKey: InternalRow = _
+  // The grouping key of the next group.
+  protected var nextGroupingKey: InternalRow = _
+  // The first row of the next group.
+  protected var firstRowInNextGroup: InternalRow = _
+  // Indicates whether we have a new group of rows to process.
+  protected var hasNewGroup: Boolean = true
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Private methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  /** Initializes buffer values for all aggregate functions. */
+  protected def initializeBuffer(): Unit = {
+    algebraicInitialProjection(EmptyRow)
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).initialize(buffer)
+      i += 1
+    }
+  }
+
+  protected def initialize(): Unit = {
+    if (inputIter.hasNext) {
+      initializeBuffer()
+      val currentRow = inputIter.next().copy()
+      // groupGenerator is a mutable projection. Since we need to track nextGroupingKey,
+      // we make a copy here.
+      nextGroupingKey = groupGenerator(currentRow).copy()
+      firstRowInNextGroup = currentRow
+    } else {
+      // The input iterator is empty.
+      hasNewGroup = false
+    }
+  }
+
+  /** Processes rows in the current group. It will stop when it finds a new group. */
+  private def processCurrentGroup(): Unit = {
+    currentGroupingKey = nextGroupingKey
+    // Now, we will start to find all rows belonging to this group.
+    // We create a variable to track if we see the next group.
+    var findNextGroup = false
+    // firstRowInNextGroup is the first row of this group. We first process it.
+    processRow(firstRowInNextGroup)
+    // The search will stop when we see the next group or there is no
+    // input row left in the iter.
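+    // For sorted input such as [k1, k1, k2, ...], the two k1 rows are processed
+    // here; the first k2 row is then stashed in firstRowInNextGroup and
+    // nextGroupingKey becomes a copy of k2 (copied because groupGenerator
+    // reuses its output row).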
+    while (inputIter.hasNext && !findNextGroup) {
+      val currentRow = inputIter.next()
+      // Get the grouping key based on the grouping expressions.
+      // For the comparison below, we do not need to make a copy of groupingKey.
+      val groupingKey = groupGenerator(currentRow)
+      // Check if the current row belongs to the current group.
+      if (currentGroupingKey == groupingKey) {
+        processRow(currentRow)
+      } else {
+        // We found a new group.
+        findNextGroup = true
+        nextGroupingKey = groupingKey.copy()
+        firstRowInNextGroup = currentRow.copy()
+      }
+    }
+    // We have not seen a new group. It means that there are no rows left in the
+    // input iterator; the current group is the last group of the iterator.
+    if (!findNextGroup) {
+      hasNewGroup = false
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Public methods
+  ///////////////////////////////////////////////////////////////////////////
+
+  override final def hasNext: Boolean = hasNewGroup
+
+  override final def next(): InternalRow = {
+    if (hasNext) {
+      // Process the current group.
+      processCurrentGroup()
+      // Generate output row for the current group.
+      val outputRow = generateOutput()
+      // Initialize buffer values for the next group.
+      initializeBuffer()
+
+      outputRow
+    } else {
+      // No more results.
+      throw new NoSuchElementException
+    }
+  }
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Methods that need to be implemented
+  ///////////////////////////////////////////////////////////////////////////
+
+  protected def initialBufferOffset: Int
+
+  protected def processRow(row: InternalRow): Unit
+
+  protected def generateOutput(): InternalRow
+
+  ///////////////////////////////////////////////////////////////////////////
+  // Initialize this iterator
+  ///////////////////////////////////////////////////////////////////////////
+
+  initialize()
+}
+
+/**
+ * An iterator only used to group input rows according to values of `groupingExpressions`.
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ */
+class GroupingIterator(
+    groupingExpressions: Seq[NamedExpression],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    Nil,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val resultProjection =
+    newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))()
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Since we only do grouping, there is nothing to do here.
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    resultProjection(currentGroupingKey)
+  }
+}
+
+/**
+ * An iterator used to do partial aggregations (for those aggregate functions with mode Partial).
+ * It assumes that input rows are already grouped by values of `groupingExpressions`.
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // This projection is used to update buffer values for all AlgebraicAggregates.
+  private val algebraicUpdateProjection = {
+    val bufferSchema = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.bufferAttributes
+      case agg: AggregateFunction2 => agg.bufferAttributes
+    }
+    val updateExpressions = aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.updateExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+    newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer)
+  }
+
+  override protected def initialBufferOffset: Int = 0
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicUpdateProjection(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).update(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We just output the grouping expressions and the underlying buffer.
+    joinedRow(currentGroupingKey, buffer).copy()
+  }
+}
+
+/**
+ * An iterator used to do partial merge aggregations (for those aggregate functions with mode
+ * PartialMerge). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer the same
+ * length as an input row.
+ *
+ * The format of its output rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ */
+class PartialMergeSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  private val placeholderAttributes =
+    Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
+
+  // This projection is used to merge buffer values for all AlgebraicAggregates.
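+  // Merging evaluates mergeExpressions over joinedRow(buffer, inputRow):
+  // bufferAttributes resolve into this operator's buffer, cloneBufferAttributes
+  // resolve into the incoming partial buffer, and the placeholders keep both
+  // sides aligned at the same offsets.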
+  private val algebraicMergeProjection = {
+    val bufferSchemata =
+      placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.bufferAttributes
+        case agg: AggregateFunction2 => agg.bufferAttributes
+      } ++ placeholderAttributes ++ aggregateFunctions.flatMap {
+        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
+        case agg: AggregateFunction2 => agg.cloneBufferAttributes
+      }
+    val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
+      case ae: AlgebraicAggregate => ae.mergeExpressions
+      case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
+    }
+
+    newMutableProjection(mergeExpressions, bufferSchemata)()
+  }
+
+  // This projection is used to extract aggregation buffers from the underlying buffer.
+  // We need it because the underlying buffer has placeholders at its beginning.
+  private val extractsBufferValues = {
+    val expressions = aggregateFunctions.flatMap {
+      case agg => agg.bufferAttributes
+    }
+
+    newMutableProjection(expressions, inputAttributes)()
+  }
+
+  override protected def initialBufferOffset: Int = groupingExpressions.length
+
+  override protected def processRow(row: InternalRow): Unit = {
+    // Process all algebraic aggregate functions.
+    algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
+    // Process all non-algebraic aggregate functions.
+    var i = 0
+    while (i < nonAlgebraicAggregateFunctions.length) {
+      nonAlgebraicAggregateFunctions(i).merge(buffer, row)
+      i += 1
+    }
+  }
+
+  override protected def generateOutput(): InternalRow = {
+    // We output grouping expressions and aggregation buffers.
+    joinedRow(currentGroupingKey, extractsBufferValues(buffer))
+  }
+}
+
+/**
+ * An iterator used to do final aggregations (for those aggregate functions with mode
+ * Final). It assumes that input rows are already grouped by values of
+ * `groupingExpressions`.
+ * The format of its input rows is:
+ * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN|
+ *
+ * The format of its internal buffer is:
+ * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN|
+ * Every placeholder is for a grouping expression.
+ * The actual buffers are stored after placeholderN.
+ * The reason we have placeholders here is to make the underlying buffer the same
+ * length as an input row.
+ *
+ * The format of its output rows is represented by the schema of `resultExpressions`.
+ */
+class FinalSortAggregationIterator(
+    groupingExpressions: Seq[NamedExpression],
+    aggregateExpressions: Seq[AggregateExpression2],
+    aggregateAttributes: Seq[Attribute],
+    resultExpressions: Seq[NamedExpression],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection),
+    inputAttributes: Seq[Attribute],
+    inputIter: Iterator[InternalRow])
+  extends SortAggregationIterator(
+    groupingExpressions,
+    aggregateExpressions,
+    newMutableProjection,
+    inputAttributes,
+    inputIter) {
+
+  // The result of aggregate functions.
+  private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
+
+  // The projection used to generate the output rows of this operator.
+  // This is only used when we are generating final results of aggregate functions.
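+  // In generateOutput, this projection consumes
+  // joinedRow(currentGroupingKey, aggregateResult), so its input schema is the
+  // grouping attributes followed by the aggregate result attributes.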
+ private val resultProjection = + newMutableProjection( + resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // This projection is used to merge buffer values for all AlgebraicAggregates. + private val algebraicMergeProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to evaluate all AlgebraicAggregates. + private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override protected def initialBufferOffset: Int = groupingExpressions.length + + override def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + if (groupingExpressions.isEmpty) { + // If there is no grouping expression, we need to generate a single row as the output. + initializeBuffer() + // Right now, the buffer only contains initial buffer values. Because + // merging two buffers with initial values will generate a row that + // still store initial values. We set the currentRow as the copy of the current buffer. + val currentRow = buffer.copy() + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + } + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(buffer) + // Generate results for all non-algebraic aggregate functions. 
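+    // algebraicEvalProjection left NoOp results in the slots belonging to
+    // non-algebraic functions; the loop below fills those slots by calling
+    // eval(buffer) on each of them.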
+ var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +} + +/** + * An iterator used to do both final aggregations (for those aggregate functions with mode + * Final) and complete aggregations (for those aggregate functions with mode Complete). + * It assumes that input rows are already grouped by values of `groupingExpressions`. + * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN| + * col1 to colM are columns used by aggregate functions with Complete mode. + * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with + * Final mode. + * + * The format of its internal buffer is: + * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)| + * The first N placeholders represent slots of grouping expressions. + * Then, next M placeholders represent slots of col1 to colM. + * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with + * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode + * Complete. The reason that we have placeholders at here is to make our underlying buffer + * have the same length with a input row. + * + * The format of its output rows is represented by the schema of `resultExpressions`. + */ +class FinalAndCompleteSortAggregationIterator( + override protected val initialBufferOffset: Int, + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + // TODO: document the ordering + finalAggregateExpressions ++ completeAggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // The result of aggregate functions. + private val aggregateResult: MutableRow = + new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length) + + // The projection used to generate the output rows of this operator. + // This is only used when we are generating final results of aggregate functions. + private val resultProjection = { + val inputSchema = + groupingExpressions.map(_.toAttribute) ++ + finalAggregateAttributes ++ + completeAggregateAttributes + newMutableProjection(resultExpressions, inputSchema)() + } + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // All aggregate functions with mode Final. + private val finalAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) + var i = 0 + while (i < finalAggregateExpressions.length) { + functions(i) = aggregateFunctions(i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Final. 
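+  // Within aggregateFunctions, the Final functions occupy the first
+  // finalAggregateExpressions.length slots and the Complete functions follow,
+  // matching the finalAggregateExpressions ++ completeAggregateExpressions
+  // order passed to the superclass constructor.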
+ private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + finalAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // All aggregate functions with mode Complete. + private val completeAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) + var i = 0 + while (i < completeAggregateExpressions.length) { + functions(i) = aggregateFunctions(finalAggregateFunctions.length + i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Complete. + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // This projection is used to merge buffer values for all AlgebraicAggregates with mode + // Final. + private val finalAlgebraicMergeProjection = { + val numCompleteOffsetAttributes = + completeAggregateFunctions.map(_.bufferAttributes.length).sum + val completeOffsetAttributes = + Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) + + val bufferSchemata = + offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } ++ completeOffsetAttributes + val mergeExpressions = + placeholderExpressions ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to update buffer values for all AlgebraicAggregates with mode + // Complete. + private val completeAlgebraicUpdateProjection = { + val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum + val finalOffsetAttributes = + Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + + val bufferSchema = + offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = + placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + // This projection is used to evaluate all AlgebraicAggregates. 
+ private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + if (groupingExpressions.isEmpty) { + // If there is no grouping expression, we need to generate a single row as the output. + initializeBuffer() + // Right now, the buffer only contains initial buffer values. Because + // merging two buffers with initial values will generate a row that + // still store initial values. We set the currentRow as the copy of the current buffer. + val currentRow = buffer.copy() + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + } + + override protected def processRow(row: InternalRow): Unit = { + val input = joinedRow(buffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(buffer, row) + i += 1 + } + + // For all aggregate functions with mode Final, merge buffers. + finalAlgebraicMergeProjection.target(buffer)(input) + i = 0 + while (i < finalNonAlgebraicAggregateFunctions.length) { + finalNonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(buffer) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala new file mode 100644 index 0000000000000..1cb27710e0480 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. + case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expressions, check if it contains AggregateExpression1. 
+        expr.find {
+          case agg: expressions.AggregateExpression1 => true
+          case other => false
+        }.isDefined
+      }
+
+      // Check if there are multiple distinct columns.
+      val aggregateExpressions = converted.aggregateExpressions.flatMap { expr =>
+        expr.collect {
+          case agg: AggregateExpression2 => agg
+        }
+      }.toSet.toSeq
+      val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct)
+      val hasMultipleDistinctColumnSets =
+        functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1
+
+      if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None
+
+    case other => None
+  }
+
+  private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = {
+    // If the plan cannot be converted, we will do a final round of checks to see if the
+    // original logical.Aggregate contains both AggregateExpression1 and
+    // AggregateExpression2. If so, we need to throw an exception.
+    val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr =>
+      expr.collect {
+        case agg: AggregateExpression2 => agg.aggregateFunction
+      }
+    }.distinct
+    if (aggregateFunction2s.nonEmpty) {
+      // For functions implemented based on the new interface, prepare a list of function names.
+      val invalidFunctions = {
+        if (aggregateFunction2s.length > 1) {
+          s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " +
+            s"and ${aggregateFunction2s.head.nodeName} are"
+        } else {
+          s"${aggregateFunction2s.head.nodeName} is"
+        }
+      }
+      val errorMessage =
+        s"${invalidFunctions} implemented based on the new Aggregate Function " +
+          s"interface and cannot be used with functions implemented based on " +
+          s"the old Aggregate Function interface."
+      throw new AnalysisException(errorMessage)
+    }
+  }
+
+  def tryConvert(
+      plan: LogicalPlan,
+      useNewAggregation: Boolean,
+      codeGenEnabled: Boolean): Option[Aggregate] = plan match {
+    case p: Aggregate if useNewAggregation && codeGenEnabled =>
+      val converted = tryConvert(p)
+      if (converted.isDefined) {
+        converted
+      } else {
+        checkInvalidAggregateFunction2(p)
+        None
+      }
+    case p: Aggregate =>
+      checkInvalidAggregateFunction2(p)
+      None
+    case other => None
+  }
+
+  def planAggregateWithoutDistinct(
+      groupingExpressions: Seq[Expression],
+      aggregateExpressions: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+    // 1. Create an Aggregate Operator for partial aggregations.
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
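+      // For example, a grouping expression like (key + 1) is wrapped in
+      // Alias(key + 1, "(key + 1)")() so that downstream operators can refer to it
+      // through the alias's attribute.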
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+    val partialAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Partial, isDistinct)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        namedGroupingExpressions.map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for final aggregations.
+    val finalAggregateExpressions = aggregateExpressions.map {
+      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        AggregateExpression2(aggregateFunction, Final, isDistinct)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transformDown {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAggregate = Aggregate2Sort(
+      Some(namedGroupingAttributes),
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      rewrittenResultExpressions,
+      partialAggregate)
+
+    finalAggregate :: Nil
+  }
+
+  def planAggregateWithOneDistinct(
+      groupingExpressions: Seq[Expression],
+      functionsWithDistinct: Seq[AggregateExpression2],
+      functionsWithoutDistinct: Seq[AggregateExpression2],
+      aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute],
+      resultExpressions: Seq[NamedExpression],
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    // 1. Create an Aggregate Operator for partial aggregations.
+    // The grouping expressions are original groupingExpressions and
+    // distinct columns. For example, for avg(distinct value) ... group by key
+    // the grouping expressions of this Aggregate Operator will be [key, value].
+    val namedGroupingExpressions = groupingExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+      // So, when we generate the result of the operator, the Aggregate Operator
+      // can directly get the Seq of attributes representing the grouping expressions.
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val groupExpressionMap = namedGroupingExpressions.toMap
+    val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
+
+    // It is safe to call head here since functionsWithDistinct has at least one
+    // AggregateExpression2.
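+    // All elements of functionsWithDistinct also share the same distinct column set,
+    // because plans with multiple distinct column sets were rejected in tryConvert.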
+    val distinctColumnExpressions =
+      functionsWithDistinct.head.aggregateFunction.children
+    val namedDistinctColumnExpressions = distinctColumnExpressions.map {
+      case ne: NamedExpression => ne -> ne
+      case other =>
+        val withAlias = Alias(other, other.toString)()
+        other -> withAlias
+    }
+    val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap
+    val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute)
+
+    val partialAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Partial, false)
+    }
+    val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
+      agg.aggregateFunction.bufferAttributes
+    }
+    val partialAggregate =
+      Aggregate2Sort(
+        None: Option[Seq[Expression]],
+        (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2),
+        partialAggregateExpressions,
+        partialAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes,
+        child)
+
+    // 2. Create an Aggregate Operator for partial merge aggregations.
+    val partialMergeAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, PartialMerge, false)
+    }
+    val partialMergeAggregateAttributes =
+      partialMergeAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val partialMergeAggregate =
+      Aggregate2Sort(
+        Some(namedGroupingAttributes),
+        namedGroupingAttributes ++ distinctColumnAttributes,
+        partialMergeAggregateExpressions,
+        partialMergeAggregateAttributes,
+        namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes,
+        partialAggregate)
+
+    // 3. Create an Aggregate Operator for the final and complete aggregations.
+    val finalAggregateExpressions = functionsWithoutDistinct.map {
+      case AggregateExpression2(aggregateFunction, mode, _) =>
+        AggregateExpression2(aggregateFunction, Final, false)
+    }
+    val finalAggregateAttributes =
+      finalAggregateExpressions.map {
+        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      }
+    val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map {
+      // Children of an AggregateFunction with the DISTINCT keyword have already
+      // been evaluated. Here, we need to replace the original children
+      // with AttributeReferences.
+      case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) =>
+        val rewrittenAggregateFunction = aggregateFunction.transformDown {
+          case expr if distinctColumnExpressionMap.contains(expr) =>
+            distinctColumnExpressionMap(expr).toAttribute
+        }.asInstanceOf[AggregateFunction2]
+        // We rewrite the aggregate function to a non-distinct aggregation because
+        // its input will have distinct arguments.
+        val rewrittenAggregateExpression =
+          AggregateExpression2(rewrittenAggregateFunction, Complete, false)
+
+        val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct)
+        (rewrittenAggregateExpression -> aggregateFunctionAttribute)
+    }.unzip
+
+    val rewrittenResultExpressions = resultExpressions.map { expr =>
+      expr.transform {
+        case agg: AggregateExpression2 =>
+          aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute
+        case expression =>
+          // We do not rely on the equality check here since attributes may
+          // differ cosmetically. Instead, we use semanticEquals.
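+          // For example, an attribute may appear with different capitalization or
+          // qualifiers; semanticEquals treats such cosmetic variants as equal.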
+          groupExpressionMap.collectFirst {
+            case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+          }.getOrElse(expression)
+      }.asInstanceOf[NamedExpression]
+    }
+    val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort(
+      namedGroupingAttributes ++ distinctColumnAttributes,
+      namedGroupingAttributes,
+      finalAggregateExpressions,
+      finalAggregateAttributes,
+      completeAggregateExpressions,
+      completeAggregateAttributes,
+      rewrittenResultExpressions,
+      partialMergeAggregate)
+
+    finalAndCompleteAggregate :: Nil
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
new file mode 100644
index 0000000000000..6c49a906c848a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.expressions.aggregate
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
+import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+
+/**
+ * The abstract class for implementing user-defined aggregate functions.
+ */
+abstract class UserDefinedAggregateFunction extends Serializable {
+
+  /**
+   * A [[StructType]] representing the data types of the input arguments of this aggregate
+   * function. For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
+   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * input argument. Users can choose names to identify the input arguments.
+   */
+  def inputSchema: StructType
+
+  /**
+   * A [[StructType]] representing the data types of values in the aggregation buffer.
+   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
+   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
+   * the returned [[StructType]] will look like
+   *
+   * ```
+   *   StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType)))
+   * ```
+   *
+   * The name of a field of this [[StructType]] is only used to identify the corresponding
+   * buffer value. Users can choose names to identify the buffer values.
+   */
+  def bufferSchema: StructType
+
+  /**
+   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
+   */
+  def returnDataType: DataType
+
+  /** Indicates if this function is deterministic. */
+  def deterministic: Boolean
+
+  /**
+   * Initializes the given aggregation buffer. Initial values set by this method should satisfy
+   * the condition that when merging two buffers with initial values, the new buffer should
+   * still store initial values.
+   */
+  def initialize(buffer: MutableAggregationBuffer): Unit
+
+  /** Updates the given aggregation buffer `buffer` with new input data from `input`. */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit
+
+  /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
+
+  /**
+   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
+   * aggregation buffer.
+   */
+  def evaluate(buffer: Row): Any
+}
+
+private[sql] abstract class AggregationBuffer(
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int)
+  extends Row {
+
+  override def length: Int = toCatalystConverters.length
+
+  protected val offsets: Array[Int] = {
+    val newOffsets = new Array[Int](length)
+    var i = 0
+    while (i < newOffsets.length) {
+      newOffsets(i) = bufferOffset + i
+      i += 1
+    }
+    newOffsets
+  }
+}
+
+/**
+ * A mutable [[Row]] representing a mutable aggregation buffer.
+ */
+class MutableAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingBuffer: MutableRow)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingBuffer(offsets(i)))
+  }
+
+  def update(i: Int, value: Any): Unit = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not update ${i}th value in this buffer because it only has $length values.")
+    }
+    underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value))
+  }
+
+  override def copy(): MutableAggregationBuffer = {
+    new MutableAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingBuffer)
+  }
+}
+
+/**
+ * A [[Row]] representing an immutable aggregation buffer.
+ */
+class InputAggregationBuffer private[sql] (
+    toCatalystConverters: Array[Any => Any],
+    toScalaConverters: Array[Any => Any],
+    bufferOffset: Int,
+    var underlyingInputBuffer: Row)
+  extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) {
+
+  override def get(i: Int): Any = {
+    if (i >= length || i < 0) {
+      throw new IllegalArgumentException(
+        s"Could not access ${i}th value in this buffer because it only has $length values.")
+    }
+    toScalaConverters(i)(underlyingInputBuffer(offsets(i)))
+  }
+
+  override def copy(): InputAggregationBuffer = {
+    new InputAggregationBuffer(
+      toCatalystConverters,
+      toScalaConverters,
+      bufferOffset,
+      underlyingInputBuffer)
+  }
+}
+
+/**
+ * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the
+ * internal aggregation code path.
+ * @param children the input expressions of this aggregate function.
+ * @param udaf the user-defined aggregate function to be wrapped.
+ */
+case class ScalaUDAF(
+    children: Seq[Expression],
+    udaf: UserDefinedAggregateFunction)
+  extends AggregateFunction2 with Logging {
+
+  require(
+    children.length == udaf.inputSchema.length,
+    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
+      s"but ${children.length} are provided.")
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = udaf.returnDataType
+
+  override def deterministic: Boolean = udaf.deterministic
+
+  override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType)
+
+  override val bufferSchema: StructType = udaf.bufferSchema
+
+  override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes
+
+  override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance())
+
+  val childrenSchema: StructType = {
+    val inputFields = children.zipWithIndex.map {
+      case (child, index) =>
+        StructField(s"input$index", child.dataType, child.nullable, Metadata.empty)
+    }
+    StructType(inputFields)
+  }
+
+  lazy val inputProjection = {
+    val inputAttributes = childrenSchema.toAttributes
+    log.debug(
+      s"Creating MutableProj: $children, inputSchema: $inputAttributes.")
+    try {
+      GenerateMutableProjection.generate(children, inputAttributes)()
+    } catch {
+      case e: Exception =>
+        log.error("Failed to generate mutable projection, falling back to interpreted", e)
+        new InterpretedMutableProjection(children, inputAttributes)
+    }
+  }
+
+  val inputToScalaConverters: Any => Any =
+    CatalystTypeConverters.createToScalaConverter(childrenSchema)
+
+  val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToCatalystConverter(field.dataType)
+  }
+
+  val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field =>
+    CatalystTypeConverters.createToScalaConverter(field.dataType)
+  }
+
+  lazy val inputAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  lazy val mutableAggregateBuffer: MutableAggregationBuffer =
+    new MutableAggregationBuffer(
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      bufferOffset,
+      null)
+
+  override def initialize(buffer: MutableRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.initialize(mutableAggregateBuffer)
+  }
+
+  override def update(buffer: MutableRow, input: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer
+
+    udaf.update(
+      mutableAggregateBuffer,
+      inputToScalaConverters(inputProjection(input)).asInstanceOf[Row])
+  }
+
+  override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
+    mutableAggregateBuffer.underlyingBuffer = buffer1
+    inputAggregateBuffer.underlyingInputBuffer = buffer2
+
+    udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
+  }
+
+  override def eval(buffer: InternalRow = null): Any = {
+    inputAggregateBuffer.underlyingInputBuffer = buffer
+
+    udaf.evaluate(inputAggregateBuffer)
+  }
+
+  override def toString: String = {
+    s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})"""
+  }
+
+  override def nodeName: String = udaf.getClass.getSimpleName
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 28159cbd5ab96..bfeecbe8b2ab5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2420,7 +2420,7 @@ object functions {
    * @since 1.5.0
    */
   def callUDF(udfName: String, cols: Column*): Column = {
-    UnresolvedFunction(udfName, cols.map(_.expr))
+    UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false)
   }
 
   /**
@@ -2449,7 +2449,7 @@ object functions {
       exprs(i) = cols(i).expr
       i += 1
     }
-    UnresolvedFunction(udfName, exprs)
+    UnresolvedFunction(udfName, exprs, isDistinct = false)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index beee10173fbc4..ab8dce603c117 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.DefaultParserDialect
 import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.execution.aggregate.Aggregate2Sort
 import org.apache.spark.sql.execution.GeneratedAggregate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.TestData._
@@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       var hasGeneratedAgg = false
       df.queryExecution.executedPlan.foreach {
         case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true
+        case newAggregate: Aggregate2Sort => hasGeneratedAgg = true
         case _ =>
       }
       if (!hasGeneratedAgg) {
@@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
       // Aggregate with Code generation handling all null values
       testCodeGen(
         "SELECT sum('a'), avg('a'), count(null) FROM testData",
-        Row(0, null, 0) :: Nil)
+        Row(null, null, 0) :: Nil)
     } finally {
       sqlContext.dropTempTable("testData3x")
       sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 3dd24130af81a..3d71deb13e884 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext._
@@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution}
 
 class PlannerSuite extends SparkFunSuite {
 
+  private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+    val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
+    val planned =
+      plannedOption.getOrElse(
+        fail(s"Could not plan aggregation query $query. Is it an aggregation query?"))
+    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+
+    // For the new aggregation code path, there will be three aggregate operators for
+    // distinct aggregations.
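+    // (partial aggregation, partial merge, and then the final and complete aggregation),
+    // whereas non-distinct aggregations only use two (partial and final).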
+    assert(
+      aggregations.size == 2 || aggregations.size == 3,
+      s"The plan of query $query does not have partial aggregations.")
+  }
+
   test("unions are collapsed") {
     val query = testData.unionAll(testData).unionAll(testData).logicalPlan
     val planned = BasicOperators(query).head
@@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite {
 
   test("count is partially aggregated") {
     val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    val planned = HashAggregation(query).head
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
-
-    assert(aggregations.size === 2)
+    testPartialAggregationPlan(query)
   }
 
   test("count distinct is partially aggregated") {
     val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
       testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    testPartialAggregationPlan(query)
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 31a49a3683338..24a758f53170a 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -833,6 +833,7 @@ abstract class HiveWindowFunctionQueryFileBaseSuite
     "windowing_adjust_rowcontainer_sz"
   )
 
+  // Only run those query tests in the realWhiteList (do not try other ignored query files).
   override def testCases: Seq[(String, File)] = super.testCases.filter {
     case (name, _) => realWhiteList.contains(name)
   }
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
index f458567e5d7ea..1fe4fe9629c02 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.hive.execution
 
+import java.io.File
+
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.hive.test.TestHive
 
@@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite {
     "join_reorder4",
     "join_star"
   )
+
+  // Only run those query tests in the realWhiteList (do not try other ignored query files).
+ override def testCases: Seq[(String, File)] = super.testCases.filter { + case (name, _) => realWhiteList.contains(name) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cec7685bb6859..4cdb83c5116f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { DataSinks, Scripts, HashAggregation, + Aggregation, LeftSemiJoin, HashJoin, BasicOperators, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f5574509b0b38..8518e333e8058 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. + case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ case Token("TOK_NULL", Nil) => Literal.create(null, NullType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4d23c7035c03d..3259b50acc765 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction( private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF( /** It is used as a wrapper for the hive functions which uses UDAF interface */ private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = UDAF @@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], - base: AggregateExpression, + base: AggregateExpression1, isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction + extends AggregateFunction1 with HiveInspectors { def this() = this(null, null, null) diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java new file mode 100644 index 0000000000000..5c9d0e97a99c6 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class MyDoubleAvg extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleAvg() { + List inputfields = new ArrayList(); + inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputfields); + + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); + bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType returnDataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, null); + buffer.update(1, 0L); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + buffer.update(1, 1L); + } else { + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + buffer.update(1, buffer.getLong(1) + 1L); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + buffer1.update(0, buffer2.getDouble(0)); + buffer1.update(1, buffer2.getLong(1)); + } else { + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + return null; + } else { + return buffer.getDouble(0) / buffer.getLong(1) + 100.0; + } + } +} + diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java new file mode 100644 index 
0000000000000..1d4587a27c787 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.Row; + +public class MyDoubleSum extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleSum() { + List inputfields = new ArrayList(); + inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputfields); + + List bufferFields = new ArrayList(); + bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType returnDataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, null); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + } else { + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + buffer1.update(0, buffer2.getDouble(0)); + } else { + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + return null; + } else { + return buffer.getDouble(0); + } + } +} diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 new file mode 100644 index 0000000000000..44b2a42cc26c5 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 @@ -0,0 +1 @@ +unhex(str) - Converts hexadecimal argument to binary diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 new file mode 100644 index 0000000000000..97af3b812a429 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 @@ -0,0 +1,14 @@ +unhex(str) - Converts hexadecimal argument to binary +Performs the inverse operation of HEX(str). That is, it interprets +each pair of hexadecimal digits in the argument as a number and +converts it to the byte representation of the number. The +resulting characters are returned as a binary string. + +Example: +> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1; +'MySQL' + +The characters in the argument string must be legal hexadecimal +digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters +any nonhexadecimal digits in the argument, it returns NULL. Also, +if there are an odd number of characters a leading 0 is appended. diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e new file mode 100644 index 0000000000000..b4a6f2b692227 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e @@ -0,0 +1 @@ +MySQL 1267 a -4 diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 new file mode 100644 index 0000000000000..3a67adaf0a9a8 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 @@ -0,0 +1 @@ +NULL NULL NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala new file mode 100644 index 0000000000000..0375eb79add95 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.execution.aggregate.Aggregate2Sort +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.scalatest.BeforeAndAfterAll +import test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} + +class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll { + + override val sqlContext = TestHive + import sqlContext.implicits._ + + var originalUseAggregate2: Boolean = _ + + override def beforeAll(): Unit = { + originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 + sqlContext.sql("set spark.sql.useAggregate2=true") + val data1 = Seq[(Integer, Integer)]( + (1, 10), + (null, -60), + (1, 20), + (1, 30), + (2, 0), + (null, -10), + (2, -1), + (2, null), + (2, null), + (null, 100), + (3, null), + (null, null), + (3, null)).toDF("key", "value") + data1.write.saveAsTable("agg1") + + val data2 = Seq[(Integer, Integer, Integer)]( + (1, 10, -10), + (null, -60, 60), + (1, 30, -30), + (1, 30, 30), + (2, 1, 1), + (null, -10, 10), + (2, -1, null), + (2, 1, 1), + (2, null, 1), + (null, 100, -10), + (3, null, 3), + (null, null, null), + (3, null, null)).toDF("key", "value1", "value2") + data2.write.saveAsTable("agg2") + + val emptyDF = sqlContext.createDataFrame( + sqlContext.sparkContext.emptyRDD[Row], + StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil)) + emptyDF.registerTempTable("emptyTable") + + // Register UDAFs + sqlContext.udaf.register("mydoublesum", new MyDoubleSum) + sqlContext.udaf.register("mydoubleavg", new MyDoubleAvg) + } + + override def afterAll(): Unit = { + sqlContext.sql("DROP TABLE IF EXISTS agg1") + sqlContext.sql("DROP TABLE IF EXISTS agg2") + sqlContext.dropTempTable("emptyTable") + sqlContext.sql(s"set spark.sql.useAggregate2=$originalUseAggregate2") + } + + test("empty table") { + // If there is no GROUP BY clause and the table is empty, we will generate a single row. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | AVG(value), + | COUNT(*), + | COUNT(key), + | COUNT(value), + | FIRST(key), + | LAST(value), + | MAX(key), + | MIN(value), + | SUM(key), + | COUNT(DISTINCT value) + |FROM emptyTable + """.stripMargin), + Row(null, 0, 0, 0, null, null, null, null, null, 0) :: Nil) + + // If there is a GROUP BY clause and the table is empty, there is no output. 
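+    // With a GROUP BY clause there are no groups to aggregate over, so no rows are
+    // produced, unlike the no-GROUP-BY case above, which always returns a single row.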
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  AVG(value),
+          |  COUNT(*),
+          |  COUNT(value),
+          |  FIRST(value),
+          |  LAST(value),
+          |  MAX(value),
+          |  MIN(value),
+          |  SUM(value),
+          |  COUNT(DISTINCT value)
+          |FROM emptyTable
+          |GROUP BY key
+        """.stripMargin),
+      Nil)
+  }
+
+  test("only do grouping") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT DISTINCT value1, key
+          |FROM agg2
+        """.stripMargin),
+      Row(10, 1) ::
+        Row(-60, null) ::
+        Row(30, 1) ::
+        Row(1, 2) ::
+        Row(-10, null) ::
+        Row(-1, 2) ::
+        Row(null, 2) ::
+        Row(100, null) ::
+        Row(null, 3) ::
+        Row(null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT value1, key
+          |FROM agg2
+          |GROUP BY key, value1
+        """.stripMargin),
+      Row(10, 1) ::
+        Row(-60, null) ::
+        Row(30, 1) ::
+        Row(1, 2) ::
+        Row(-10, null) ::
+        Row(-1, 2) ::
+        Row(null, 2) ::
+        Row(100, null) ::
+        Row(null, 3) ::
+        Row(null, null) :: Nil)
+  }
+
+  test("case-insensitive resolution") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value), kEY - 100
+          |FROM agg1
+          |GROUP BY Key - 100
+        """.stripMargin),
+      Row(20.0, -99) :: Row(-0.5, -98) :: Row(null, -97) :: Row(10.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT sum(distinct value1), kEY - 100, count(distinct value1)
+          |FROM agg2
+          |GROUP BY Key - 100
+        """.stripMargin),
+      Row(40, -99, 2) :: Row(0, -98, 2) :: Row(null, -97, 0) :: Row(30, null, 3) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT valUe * key - 100
+          |FROM agg1
+          |GROUP BY vAlue * keY - 100
+        """.stripMargin),
+      Row(-90) ::
+        Row(-80) ::
+        Row(-70) ::
+        Row(-100) ::
+        Row(-102) ::
+        Row(null) :: Nil)
+  }
+
+  test("test average no key in output") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(-0.5) :: Row(20.0) :: Row(null) :: Row(10.0) :: Nil)
+  }
+
+  test("test average") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT key, avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 20.0) :: Row(2, -0.5) :: Row(3, null) :: Row(null, 10.0) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(20.0, 1) :: Row(-0.5, 2) :: Row(null, 3) :: Row(10.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value) + 1.5, key + 10
+          |FROM agg1
+          |GROUP BY key + 10
+        """.stripMargin),
+      Row(21.5, 11) :: Row(1.0, 12) :: Row(null, 13) :: Row(11.5, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(value) FROM agg1
+        """.stripMargin),
+      Row(11.125) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT avg(null)
+        """.stripMargin),
+      Row(null) :: Nil)
+  }
+
+  test("udaf") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  mydoubleavg(value),
+          |  avg(value - key),
+          |  mydoublesum(value - 1.5 * key),
+          |  avg(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(1, 64.5, 120.0, 19.0, 55.5, 20.0) ::
+        Row(2, 5.0, 99.5, -2.5, -7.0, -0.5) ::
+        Row(3, null, null, null, null, null) ::
+        Row(null, null, 110.0, null, null, 10.0) :: Nil)
+  }
+
+  test("non-AlgebraicAggregate aggregate function") {
+    checkAnswer(
+      sqlContext.sql(
+        """
+          |SELECT mydoublesum(value), key
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin),
+      Row(60.0, 1) :: Row(-1.0, 2) :: Row(null, 3) :: Row(30.0, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.sql(
+ """ + |SELECT mydoublesum(value) FROM agg1 + """.stripMargin), + Row(89.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(null) + """.stripMargin), + Row(null) :: Nil) + } + + test("non-AlgebraicAggregate and AlgebraicAggregate aggreguate function") { + checkAnswer( + sqlContext.sql( + """ + |SELECT mydoublesum(value), key, avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(60.0, 1, 20.0) :: + Row(-1.0, 2, -0.5) :: + Row(null, 3, null) :: + Row(30.0, null, 10.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoublesum(value + 1.5 * key), + | avg(value - key), + | key, + | mydoublesum(value - 1.5 * key), + | avg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin), + Row(64.5, 19.0, 1, 55.5, 20.0) :: + Row(5.0, -2.5, 2, -7.0, -0.5) :: + Row(null, null, 3, null, null) :: + Row(null, null, null, null, 10.0) :: Nil) + } + + test("single distinct column set") { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100.0)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + } + + test("test count") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1) :: + Row(1, -60, 1, 1, null) :: + Row(2, 30, 2, 2, 1) :: + Row(2, 1, 2, 2, 2) :: + Row(1, -10, 1, 1, null) :: + Row(0, -1, 1, 1, 2) :: + Row(1, null, 1, 1, 2) :: + Row(1, 100, 1, 1, null) :: + Row(1, null, 2, 2, 3) :: + Row(0, null, 1, 1, null) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value2), + | value1, + | count(*), + | count(1), + | key, + | count(DISTINCT abs(value2)) + |FROM agg2 + |GROUP BY key, value1 + """.stripMargin), + Row(1, 10, 1, 1, 1, 1) :: + Row(1, -60, 1, 1, null, 1) :: + Row(2, 30, 2, 2, 1, 1) :: + Row(2, 1, 2, 2, 2, 1) :: + Row(1, -10, 1, 1, null, 1) :: + Row(0, -1, 1, 1, 2, 0) :: + Row(1, null, 1, 1, 2, 1) :: + Row(1, 100, 1, 1, null, 1) :: + Row(1, null, 2, 2, 3, 1) :: + Row(0, null, 1, 1, null, 0) :: Nil) + } + + test("error handling") { + sqlContext.sql(s"set spark.sql.useAggregate2=false") + var errorMessage = intercept[AnalysisException] { + sqlContext.sql( + """ + |SELECT + | key, + | sum(value + 1.5 * key), + | mydoublesum(value), + | mydoubleavg(value) + |FROM agg1 + |GROUP BY key + """.stripMargin).collect() + }.getMessage + 
+    assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+    // TODO: once we support Hive UDAF in the new interface,
+    // we can remove the following two tests.
+    sqlContext.sql(s"set spark.sql.useAggregate2=true")
+    errorMessage = intercept[AnalysisException] {
+      sqlContext.sql(
+        """
+          |SELECT
+          |  key,
+          |  mydoublesum(value + 1.5 * key),
+          |  stddev_samp(value)
+          |FROM agg1
+          |GROUP BY key
+        """.stripMargin).collect()
+    }.getMessage
+    assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
+
+    // This will fall back to the old aggregation code path.
+    val newAggregateOperators = sqlContext.sql(
+      """
+        |SELECT
+        |  key,
+        |  sum(value + 1.5 * key),
+        |  stddev_samp(value)
+        |FROM agg1
+        |GROUP BY key
+      """.stripMargin).queryExecution.executedPlan.collect {
+      case agg: Aggregate2Sort => agg
+    }
+    val message =
+      "We should fall back to the old aggregation code path if there is any aggregate function " +
+        "that cannot be converted to the new interface."
+    assert(newAggregateOperators.isEmpty, message)
+
+    sqlContext.sql(s"set spark.sql.useAggregate2=true")
+  }
+}

From b55a36bc30a628d76baa721d38789fc219eccc27 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Wed, 22 Jul 2015 09:32:42 -0700
Subject: [PATCH 005/219] [SPARK-9254] [BUILD] [HOTFIX] sbt-launch-lib.bash
 should support HTTP/HTTPS redirection

Target file(s) can be hosted on CDN nodes. HTTP/HTTPS redirection must be
supported to download these files.

Author: Cheng Lian

Closes #7597 from liancheng/spark-9254 and squashes the following commits:

fd266ca [Cheng Lian] Uses `--fail' to make curl return non-zero value and remove garbage output when the download fails
a7cbfb3 [Cheng Lian] Supports HTTP/HTTPS redirection
---
 build/sbt-launch-lib.bash | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash
index 504be48b358fa..7930a38b9674a 100755
--- a/build/sbt-launch-lib.bash
+++ b/build/sbt-launch-lib.bash
@@ -51,9 +51,13 @@ acquire_sbt_jar () {
     printf "Attempting to fetch sbt\n"
     JAR_DL="${JAR}.part"
     if [ $(command -v curl) ]; then
-      (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+      (curl --fail --location --silent ${URL1} > "${JAR_DL}" ||\
+        (rm -f "${JAR_DL}" && curl --fail --location --silent ${URL2} > "${JAR_DL}")) &&\
+        mv "${JAR_DL}" "${JAR}"
     elif [ $(command -v wget) ]; then
-      (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
+      (wget --quiet ${URL1} -O "${JAR_DL}" ||\
+        (rm -f "${JAR_DL}" && wget --quiet ${URL2} -O "${JAR_DL}")) &&\
+        mv "${JAR_DL}" "${JAR}"
     else
       printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
       exit -1

From 76520955fddbda87a5c53d0a394dedc91dce67e8 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 22 Jul 2015 11:45:51 -0700
Subject: [PATCH 006/219] [SPARK-9082] [SQL] Filter using non-deterministic
 expressions should not be pushed down

Author: Wenchen Fan

Closes #7446 from cloud-fan/filter and squashes the following commits:

330021e [Wenchen Fan] add exists to tree node
2cab68c [Wenchen Fan] more enhance
949be07 [Wenchen Fan] push down part of predicate if possible
3912f84 [Wenchen Fan] address comments
8ce15ca [Wenchen Fan] fix bug
557158e [Wenchen Fan] Filter using non-deterministic expressions should not be pushed down
---
 .../sql/catalyst/optimizer/Optimizer.scala    | 50 +++++++++++++++----
 .../optimizer/FilterPushdownSuite.scala       | 45 ++++++++++++++++-
 2 files changed, 84 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e42f0b9a247e3..d2db3dd3d078e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -541,20 +541,50 @@ object SimplifyFilters extends Rule[LogicalPlan] {
  *
  * This heuristic is valid assuming the expression evaluation cost is minimal.
  */
-object PushPredicateThroughProject extends Rule[LogicalPlan] {
+object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case filter @ Filter(condition, project @ Project(fields, grandChild)) =>
-      val sourceAliases = fields.collect { case a @ Alias(c, _) =>
-        (a.toAttribute: Attribute) -> c
-      }.toMap
-      project.copy(child = filter.copy(
-        replaceAlias(condition, sourceAliases),
-        grandChild))
+      // Create a map of Aliases to their values from the child projection.
+      // e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
+      val aliasMap = AttributeMap(fields.collect {
+        case a: Alias => (a.toAttribute, a.child)
+      })
+
+      // Split the condition into small conditions by `And`, so that we can push down part of this
+      // condition without nondeterministic expressions.
+      val andConditions = splitConjunctivePredicates(condition)
+      val nondeterministicConditions = andConditions.filter(hasNondeterministic(_, aliasMap))
+
+      // If there are no nondeterministic conditions, push down the whole condition.
+      if (nondeterministicConditions.isEmpty) {
+        project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
+      } else {
+        // If they are all nondeterministic conditions, leave the plan unchanged.
+        if (nondeterministicConditions.length == andConditions.length) {
+          filter
+        } else {
+          val deterministicConditions = andConditions.filterNot(hasNondeterministic(_, aliasMap))
+          // Push down the small conditions without nondeterministic expressions.
+          val pushedCondition = deterministicConditions.map(replaceAlias(_, aliasMap)).reduce(And)
+          Filter(nondeterministicConditions.reduce(And),
+            project.copy(child = Filter(pushedCondition, grandChild)))
+        }
+      }
+  }
+
+  private def hasNondeterministic(
+      condition: Expression,
+      sourceAliases: AttributeMap[Expression]) = {
+    condition.collect {
+      case a: Attribute if sourceAliases.contains(a) => sourceAliases(a)
+    }.exists(!_.deterministic)
   }
 
-  private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = {
-    condition transform {
-      case a: AttributeReference => sourceAliases.getOrElse(a, a)
+  // Substitute any attributes that are produced by the child projection, so that we can
+  // safely eliminate it.
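+  // For example, given Project [a + b AS c] and the condition 'c > 5', replaceAlias
+  // rewrites the condition to '(a + b) > 5' so it can be evaluated below the projection.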
+ private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { + condition.transform { + case a: Attribute => sourceAliases.getOrElse(a, a) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index dc28b3ffb59ee..0f1fde2fb0f67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.expressions.{SortOrder, Ascending, Count, Explode} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ @@ -146,6 +146,49 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("nondeterministic: can't push down filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('rand > 5 || 'a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + + test("nondeterministic: push down part of filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('rand > 5 && 'a > 5) + .analyze + + val optimized = Optimize.execute(originalQuery) + + val correctAnswer = testRelation + .where('a > 5) + .select(Rand(10).as('rand), 'a) + .where('rand > 5) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("nondeterministic: push down filter through project") { + val originalQuery = testRelation + .select(Rand(10).as('rand), 'a) + .where('a > 5 && 'a < 10) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = testRelation + .where('a > 5 && 'a < 10) + .select(Rand(10).as('rand), 'a) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("filters: combines filters") { val originalQuery = testRelation .select('a) From 86f80e2b4759e574fe3eb91695f81b644db87242 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 22 Jul 2015 12:19:59 -0700 Subject: [PATCH 007/219] [SPARK-9165] [SQL] codegen for CreateArray, CreateStruct and CreateNamedStruct JIRA: https://issues.apache.org/jira/browse/SPARK-9165 Author: Yijie Shen Closes #7537 from yjshen/array_struct_codegen and squashes the following commits: 3a6dce6 [Yijie Shen] use infix notion in createArray test 5e90f0a [Yijie Shen] resolve comments: classOf 39cefb8 [Yijie Shen] codegen for createArray createStruct & createNamedStruct --- .../expressions/complexTypeCreator.scala | 65 +++++++++++++++++-- .../expressions/ComplexTypeSuite.scala | 16 +++++ 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f9fd04c02aaef..20b1eaab8e303 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** * Returns an Array containing the evaluation of all children expressions. */ -case class CreateArray(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -45,14 +47,31 @@ case class CreateArray(children: Seq[Expression]) extends Expression with Codege children.map(_.eval(input)) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + s""" + boolean ${ev.isNull} = false; + $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "array" } /** * Returns a Row containing the evaluation of all children expressions. - * TODO: [[CreateStruct]] does not support codegen. */ -case class CreateStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -76,6 +95,24 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg InternalRow(children.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${children.size}); + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "struct" } @@ -84,7 +121,7 @@ case class CreateStruct(children: Seq[Expression]) extends Expression with Codeg * * @param children Seq(name1, val1, name2, val2, ...) 
*/ -case class CreateNamedStruct(children: Seq[Expression]) extends Expression with CodegenFallback { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip @@ -122,5 +159,23 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with InternalRow(valExprs.map(_.eval(input)): _*) } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericMutableRow].getName + s""" + boolean ${ev.isNull} = false; + final $rowClass ${ev.primitive} = new $rowClass(${valExprs.size}); + """ + + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ + }.mkString("\n") + } + override def prettyName: String = "named_struct" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index e3042143632aa..a8aee8f634e03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -117,6 +117,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("CreateArray") { + val intSeq = Seq(5, 10, 15, 20, 25) + val longSeq = intSeq.map(_.toLong) + val strSeq = intSeq.map(_.toString) + checkEvaluation(CreateArray(intSeq.map(Literal(_))), intSeq, EmptyRow) + checkEvaluation(CreateArray(longSeq.map(Literal(_))), longSeq, EmptyRow) + checkEvaluation(CreateArray(strSeq.map(Literal(_))), strSeq, EmptyRow) + + val intWithNull = intSeq.map(Literal(_)) :+ Literal.create(null, IntegerType) + val longWithNull = longSeq.map(Literal(_)) :+ Literal.create(null, LongType) + val strWithNull = strSeq.map(Literal(_)) :+ Literal.create(null, StringType) + checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + } + test("CreateStruct") { val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) From e0b7ba59a1ace9b78a1ad6f3f07fe153db20b52c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 22 Jul 2015 13:02:43 -0700 Subject: [PATCH 008/219] [SPARK-9024] Unsafe HashJoin/HashOuterJoin/HashSemiJoin This PR introduces unsafe versions (using UnsafeRow) of HashJoin, HashOuterJoin and HashSemiJoin, covering both the broadcast and shuffle variants (except FullOuterJoin, which is better implemented with SortMergeJoin). It uses a HashMap to store UnsafeRows for now; a follow-up PR will switch to BytesToBytesMap for better performance.
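In outline, the build/probe structure that this PR makes unsafe looks like the plain-Scala sketch below (illustrative names only, not the PR's code; the real implementation further down keys a java HashMap by UnsafeRow, whose hashCode/equals compare the row's raw bytes):

    import scala.collection.mutable

    // Sketch of a hash join: build a multimap from join keys to rows, then
    // stream the other side and probe it. In the PR both K and R are UnsafeRow.
    def hashJoinSketch[K, R](
        build: Iterator[R], buildKey: R => K,
        stream: Iterator[R], streamKey: R => K): Iterator[(R, R)] = {
      // Build phase: group build-side rows by their join key.
      val table = mutable.HashMap.empty[K, mutable.ArrayBuffer[R]]
      for (row <- build) {
        table.getOrElseUpdate(buildKey(row), mutable.ArrayBuffer.empty[R]) += row
      }
      // Probe phase: emit one pair per matching build row.
      stream.flatMap { s =>
        table.getOrElse(streamKey(s), mutable.ArrayBuffer.empty[R]).map(b => (s, b))
      }
    }

Keeping both sides in the UnsafeRow format avoids converting rows to Scala objects during the build and probe phases, which is where the savings come from.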
Author: Davies Liu Closes #7480 from davies/unsafe_join and squashes the following commits: 6294b1e [Davies Liu] fix projection 10583f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join dede020 [Davies Liu] fix test 84c9807 [Davies Liu] address comments a05b4f6 [Davies Liu] support UnsafeRow in LeftSemiJoinBNL and BroadcastNestedLoopJoin 611d2ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 9481ae8 [Davies Liu] return UnsafeRow after join() ca2b40f [Davies Liu] revert unrelated change 68f5cd9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 0f4380d [Davies Liu] ada a comment 69e38f5 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 1a40f02 [Davies Liu] refactor ab1690f [Davies Liu] address comments 60371f2 [Davies Liu] use UnsafeRow in SemiJoin a6c0b7d [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_join 184b852 [Davies Liu] fix style 6acbb11 [Davies Liu] fix tests 95d0762 [Davies Liu] remove println bea4a50 [Davies Liu] Unsafe HashJoin --- .../sql/catalyst/expressions/UnsafeRow.java | 50 ++++++++++- .../execution/UnsafeExternalRowSorter.java | 10 +-- .../catalyst/expressions/BoundAttribute.scala | 19 ++++- .../sql/catalyst/expressions/Projection.scala | 34 +++++++- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 32 ++----- .../joins/BroadcastLeftSemiJoinHash.scala | 5 +- .../joins/BroadcastNestedLoopJoin.scala | 37 +++++--- .../spark/sql/execution/joins/HashJoin.scala | 43 ++++++++-- .../sql/execution/joins/HashOuterJoin.scala | 82 +++++++++++++++--- .../sql/execution/joins/HashSemiJoin.scala | 74 ++++++++++------ .../sql/execution/joins/HashedRelation.scala | 85 ++++++++++++++++++- .../sql/execution/joins/LeftSemiJoinBNL.scala | 3 + .../execution/joins/LeftSemiJoinHash.scala | 4 +- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../joins/ShuffledHashOuterJoin.scala | 13 +-- .../sql/execution/rowFormatConverters.scala | 21 +++-- .../org/apache/spark/sql/UnsafeRowSuite.scala | 4 +- .../execution/joins/HashedRelationSuite.scala | 49 ++++++++--- .../spark/unsafe/hash/Murmur3_x86_32.java | 10 ++- 20 files changed, 444 insertions(+), 135 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6ce03a48e9538..7f08bf7b742dc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -20,10 +20,11 @@ import java.io.IOException; import java.io.OutputStream; -import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.UTF8String; @@ -354,7 +355,7 @@ public double getDouble(int i) { * This method is only supported on UnsafeRows that do not use ObjectPools. 
*/ @Override - public InternalRow copy() { + public UnsafeRow copy() { if (pool != null) { throw new UnsupportedOperationException( "Copy is not supported for UnsafeRows that use object pools"); @@ -404,8 +405,51 @@ public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOExcepti } } + @Override + public int hashCode() { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UnsafeRow) { + UnsafeRow o = (UnsafeRow) other; + return (sizeInBytes == o.sizeInBytes) && + ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); + } + return false; + } + + /** + * Returns the underlying bytes for this UnsafeRow. + */ + public byte[] getBytes() { + if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET + && (((byte[]) baseObject).length == sizeInBytes)) { + return (byte[]) baseObject; + } else { + byte[] bytes = new byte[sizeInBytes]; + PlatformDependent.copyMemory(baseObject, baseOffset, bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes); + return bytes; + } + } + + // This is for debugging + @Override + public String toString() { + StringBuilder build = new StringBuilder("["); + for (int i = 0; i < sizeInBytes; i += 8) { + build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); + build.append(','); + } + build.append(']'); + return build.toString(); + } + @Override public boolean anyNull() { - return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); + return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d1d81c87bb052..39fd6e1bc6d13 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -28,11 +28,10 @@ import org.apache.spark.TaskContext; import org.apache.spark.sql.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; @@ -176,12 +175,7 @@ public Iterator sort(Iterator inputIterator) throws IO */ public static boolean supportsSchema(StructType schema) { // TODO: add spilling note to explain why we do this for now: - for (StructField field : schema.fields()) { - if (!UnsafeColumnWriter.canEmbed(field.dataType())) { - return false; - } - } - return true; + return UnsafeProjection.canSupport(schema); } private static final class RowComparator extends RecordComparator { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b10a3c877434b..4a13b687bf4ce 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,7 +21,6 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.types._ /** @@ -34,7 +33,23 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, $dataType]" - override def eval(input: InternalRow): Any = input(ordinal) + // Use special getter for primitive types (for UnsafeRow) + override def eval(input: InternalRow): Any = { + if (input.isNullAt(ordinal)) { + null + } else { + dataType match { + case BooleanType => input.getBoolean(ordinal) + case ByteType => input.getByte(ordinal) + case ShortType => input.getShort(ordinal) + case IntegerType | DateType => input.getInt(ordinal) + case LongType | TimestampType => input.getLong(ordinal) + case FloatType => input.getFloat(ordinal) + case DoubleType => input.getDouble(ordinal) + case _ => input.get(ordinal) + } + } + } override def name: String = s"i[$ordinal]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 24b01ea55110e..69758e653eba0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -83,12 +83,42 @@ abstract class UnsafeProjection extends Projection { } object UnsafeProjection { + + /* + * Returns whether UnsafeProjection can support given StructType, Array[DataType] or + * Seq[Expression]. + */ + def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) + def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) + def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + + /** + * Returns an UnsafeProjection for given StructType. + */ def create(schema: StructType): UnsafeProjection = create(schema.fields.map(_.dataType)) - def create(fields: Seq[DataType]): UnsafeProjection = { + /** + * Returns an UnsafeProjection for given Array of DataTypes. + */ + def create(fields: Array[DataType]): UnsafeProjection = { val exprs = fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)) + create(exprs) + } + + /** + * Returns an UnsafeProjection for given sequence of Expressions (bounded). + */ + def create(exprs: Seq[Expression]): UnsafeProjection = { GenerateUnsafeProjection.generate(exprs) } + + /** + * Returns an UnsafeProjection for given sequence of Expressions, which will be bound to + * `inputSchema`. 
+ */ + def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { + create(exprs.map(BindReferences.bindReference(_, inputSchema))) + } } /** @@ -96,6 +126,8 @@ object UnsafeProjection { */ case class FromUnsafeProjection(fields: Seq[DataType]) extends Projection { + def this(schema: StructType) = this(schema.fields.map(_.dataType)) + private[this] val expressions = fields.zipWithIndex.map { case (dt, idx) => new BoundReference(idx, dt, true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 7ffdce60d2955..abaa4a6ce86a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab757fc7de6cd..c9d1a880f4ef4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.joins +import scala.concurrent._ +import scala.concurrent.duration._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -26,10 +29,6 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.util.ThreadUtils -import scala.collection.JavaConversions._ -import scala.concurrent._ -import scala.concurrent.duration._ - /** * :: DeveloperApi :: * Performs a outer hash join for two child relations. 
When the output RDD of this operator is @@ -58,28 +57,11 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil - private[this] lazy val (buildPlan, streamedPlan) = joinType match { - case RightOuter => (left, right) - case LeftOuter => (right, left) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - - private[this] lazy val (buildKeys, streamedKeys) = joinType match { - case RightOuter => (leftKeys, rightKeys) - case LeftOuter => (rightKeys, leftKeys) - case x => - throw new IllegalArgumentException( - s"BroadcastHashOuterJoin should not take $x as the JoinType") - } - @transient private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - // buildHashTable uses code-generated rows as keys, which are not serializable - val hashed = buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)) + val hashed = buildHashRelation(input.iterator) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) @@ -89,21 +71,21 @@ case class BroadcastHashOuterJoin( streamedPlan.execute().mapPartitions { streamedIter => val joinedRow = new JoinedRow() val hashTable = broadcastRelation.value - val keyGenerator = newProjection(streamedKeys, streamedPlan.output) + val keyGenerator = streamedKeyGenerator joinType match { case LeftOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey)) }) case RightOuter => streamedIter.flatMap(currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow) }) case x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index 2750f58b005ac..f71c0ce352904 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -40,15 +40,14 @@ case class BroadcastLeftSemiJoinHash( val buildIter = right.execute().map(_.copy()).collect().toIterator if (condition.isEmpty) { - // rowKey may be not serializable (from codegen) - val hashSet = buildKeyHashSet(buildIter, copy = true) + val hashSet = buildKeyHashSet(buildIter) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 60b4266fad8b1..700636966f8be 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -44,6 +44,19 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } + override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + + @transient private[this] lazy val resultProjection: Projection = { + if (outputsUnsafeRows) { + UnsafeProjection.create(schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + override def outputPartitioning: Partitioning = streamed.outputPartitioning override def output: Seq[Attribute] = { @@ -74,6 +87,7 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow + val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -86,11 +100,11 @@ case class BroadcastNestedLoopJoin( val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => - matchedRows += joinedRow(streamedRow, broadcastedRow).copy() + matchedRows += resultProjection(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => - matchedRows += joinedRow(broadcastedRow, streamedRow).copy() + matchedRows += resultProjection(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true includedBroadcastTuples += i case _ => @@ -100,9 +114,9 @@ case class BroadcastNestedLoopJoin( (streamRowMatched, joinType, buildSide) match { case (false, LeftOuter | FullOuter, BuildRight) => - matchedRows += joinedRow(streamedRow, rightNulls).copy() + matchedRows += resultProjection(joinedRow(streamedRow, rightNulls)).copy() case (false, RightOuter | FullOuter, BuildLeft) => - matchedRows += joinedRow(leftNulls, streamedRow).copy() + matchedRows += resultProjection(joinedRow(leftNulls, streamedRow)).copy() case _ => } } @@ -110,12 +124,9 @@ case class BroadcastNestedLoopJoin( } val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) - val allIncludedBroadcastTuples = - if (includedBroadcastTuples.count == 0) { - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - } else { - includedBroadcastTuples.reduce(_ ++ _) - } + val allIncludedBroadcastTuples = includedBroadcastTuples.fold( + new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + )(_ ++ _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -127,8 +138,10 @@ case class BroadcastNestedLoopJoin( while (i < rel.length) { if (!allIncludedBroadcastTuples.contains(i)) { (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i)) - case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls) + case (RightOuter | FullOuter, BuildRight) => + buf += resultProjection(new JoinedRow(leftNulls, rel(i))) + case (LeftOuter | FullOuter, BuildLeft) => + buf += resultProjection(new JoinedRow(rel(i), rightNulls)) case _ => } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ff85ea3f6a410..ae34409bcfcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,11 +44,20 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator: Projection = - newProjection(buildKeys, buildPlan.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe - @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = - newMutableProjection(streamedKeys, streamedPlan.output) + @transient protected lazy val streamSideKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newMutableProjection(streamedKeys, streamedPlan.output)() + } protected def hashJoin( streamIter: Iterator[InternalRow], @@ -61,8 +70,17 @@ trait HashJoin { // Mutable per row objects. private[this] val joinRow = new JoinedRow2 + private[this] val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } - private[this] val joinKeys = streamSideKeyGenerator() + private[this] val joinKeys = streamSideKeyGenerator override final def hasNext: Boolean = (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || @@ -74,7 +92,7 @@ trait HashJoin { case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) } currentMatchPosition += 1 - ret + resultProjection(ret) } /** @@ -89,8 +107,9 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() - if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashedRelation.get(joinKeys.currentValue) + val key = joinKeys(currentStreamedRow) + if (!key.anyNull) { + currentHashMatches = hashedRelation.get(key) } } @@ -103,4 +122,12 @@ trait HashJoin { } } } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 74a7db7761758..6bf2f82954046 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} -import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -38,7 +38,7 @@ trait HashOuterJoin { val left: 
SparkPlan val right: SparkPlan -override def outputPartitioning: Partitioning = joinType match { + override def outputPartitioning: Partitioning = joinType match { case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) @@ -59,6 +59,49 @@ override def outputPartitioning: Partitioning = joinType match { } } + protected[this] lazy val (buildPlan, streamedPlan) = joinType match { + case RightOuter => (left, right) + case LeftOuter => (right, left) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] lazy val (buildKeys, streamedKeys) = joinType match { + case RightOuter => (leftKeys, rightKeys) + case LeftOuter => (rightKeys, leftKeys) + case x => + throw new IllegalArgumentException( + s"HashOuterJoin should not take $x as the JoinType") + } + + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && joinType != FullOuter + && UnsafeProjection.canSupport(buildKeys) + && UnsafeProjection.canSupport(self.schema)) + } + + override def outputsUnsafeRows: Boolean = supportUnsafe + override def canProcessUnsafeRows: Boolean = supportUnsafe + + protected[this] def streamedKeyGenerator(): Projection = { + if (supportUnsafe) { + UnsafeProjection.create(streamedKeys, streamedPlan.output) + } else { + newProjection(streamedKeys, streamedPlan.output) + } + } + + @transient private[this] lazy val resultProjection: Projection = { + if (supportUnsafe) { + UnsafeProjection.create(self.schema) + } else { + new Projection { + override def apply(r: InternalRow): InternalRow = r + } + } + } + @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -76,16 +119,20 @@ override def outputPartitioning: Partitioning = joinType match { rightIter: Iterable[InternalRow]): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy() + val temp = if (rightIter != null) { + rightIter.collect { + case r if boundCondition(joinedRow.withRight(r)) => resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withRight(rightNullRow).copy :: Nil + resultProjection(joinedRow.withRight(rightNullRow)).copy :: Nil } } ret.iterator @@ -97,17 +144,21 @@ override def outputPartitioning: Partitioning = joinType match { joinedRow: JoinedRow): Iterator[InternalRow] = { val ret: Iterable[InternalRow] = { if (!key.anyNull) { - val temp = leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => - joinedRow.copy() + val temp = if (leftIter != null) { + leftIter.collect { + case l if boundCondition(joinedRow.withLeft(l)) => + resultProjection(joinedRow).copy() + } + } else { + List.empty } if (temp.isEmpty) { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } else { temp } } else { - joinedRow.withLeft(leftNullRow).copy :: Nil + resultProjection(joinedRow.withLeft(leftNullRow)).copy :: Nil } } ret.iterator @@ -159,6 +210,7 @@ override def outputPartitioning: Partitioning = joinType match { } } + // This is only used by FullOuter protected[this] def buildHashTable( iter: 
Iterator[InternalRow], keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = { @@ -178,4 +230,12 @@ override def outputPartitioning: Partitioning = joinType match { hashTable } + + protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, buildKeys, buildPlan) + } else { + HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 1b983bc3a90f9..7f49264d40354 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -32,34 +32,45 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - @transient protected lazy val rightKeyGenerator: Projection = - newProjection(rightKeys, right.output) + protected[this] def supportUnsafe: Boolean = { + (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) + && UnsafeProjection.canSupport(rightKeys) + && UnsafeProjection.canSupport(left.schema)) + } + + override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = supportUnsafe + + @transient protected lazy val leftKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(leftKeys, left.output) + } else { + newMutableProjection(leftKeys, left.output)() + } - @transient protected lazy val leftKeyGenerator: () => MutableProjection = - newMutableProjection(leftKeys, left.output) + @transient protected lazy val rightKeyGenerator: Projection = + if (supportUnsafe) { + UnsafeProjection.create(rightKeys, right.output) + } else { + newMutableProjection(rightKeys, right.output)() + } @transient private lazy val boundCondition = newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) - protected def buildKeyHashSet( - buildIter: Iterator[InternalRow], - copy: Boolean): java.util.Set[InternalRow] = { + protected def buildKeyHashSet(buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = { val hashSet = new java.util.HashSet[InternalRow]() var currentRow: InternalRow = null // Create a Hash set of buildKeys + val rightKey = rightKeyGenerator while (buildIter.hasNext) { currentRow = buildIter.next() - val rowKey = rightKeyGenerator(currentRow) + val rowKey = rightKey(currentRow) if (!rowKey.anyNull) { val keyExists = hashSet.contains(rowKey) if (!keyExists) { - if (copy) { - hashSet.add(rowKey.copy()) - } else { - // rowKey may be not serializable (from codegen) - hashSet.add(rowKey) - } + hashSet.add(rowKey.copy()) } } } @@ -67,25 +78,34 @@ trait HashSemiJoin { } protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() - val joinedRow = new JoinedRow + streamIter: Iterator[InternalRow], + hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator streamIter.filter(current => { - lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue) - !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists { - (build: InternalRow) => boundCondition(joinedRow(current, build)) - } + val key = joinKeys(current) + !key.anyNull && hashSet.contains(key) }) } + protected def buildHashRelation(buildIter: 
Iterator[InternalRow]): HashedRelation = { + if (supportUnsafe) { + UnsafeHashedRelation(buildIter, rightKeys, right) + } else { + HashedRelation(buildIter, newProjection(rightKeys, right.output)) + } + } + protected def hashSemiJoin( streamIter: Iterator[InternalRow], - hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = { - val joinKeys = leftKeyGenerator() + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = leftKeyGenerator val joinedRow = new JoinedRow - streamIter.filter(current => { - !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue) - }) + streamIter.filter { current => + val key = joinKeys(current) + lazy val rowBuffer = hashedRelation.get(key) + !key.anyNull && rowBuffer != null && rowBuffer.exists { + (row: InternalRow) => boundCondition(joinedRow(current, row)) + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 6b51f5d4151d3..8d5731afd59b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.execution.joins -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.CompactBuffer @@ -98,7 +99,6 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR } } - // TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. @@ -148,3 +148,80 @@ private[joins] object HashedRelation { } } } + + +/** + * A HashedRelation for UnsafeRow, which maps each key to a sequence of values. It is currently + * backed by a java HashMap. + * + * TODO(davies): use BytesToBytesMap + */ +private[joins] final class UnsafeHashedRelation( + private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]]) + extends HashedRelation with Externalizable { + + def this() = this(null) // Needed for serialization + + override def get(key: InternalRow): CompactBuffer[InternalRow] = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + // The cast below is safe thanks to type erasure + hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]] + } + + override def writeExternal(out: ObjectOutput): Unit = { + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + } + + override def readExternal(in: ObjectInput): Unit = { + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + } +} + +private[joins] object UnsafeHashedRelation { + + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + buildPlan: SparkPlan, + sizeEstimate: Int = 64): HashedRelation = { + val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) + apply(input, boundedKeys, buildPlan.schema, sizeEstimate) + } + + // Used for tests + def apply( + input: Iterator[InternalRow], + buildKeys: Seq[Expression], + rowSchema: StructType, + sizeEstimate: Int): HashedRelation = { + + // TODO: Use BytesToBytesMap.
+ val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) + val toUnsafe = UnsafeProjection.create(rowSchema) + val keyGenerator = UnsafeProjection.create(buildKeys) + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + val currentRow = input.next() + val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { + currentRow.asInstanceOf[UnsafeRow] + } else { + toUnsafe(currentRow) + } + val rowKey = keyGenerator(unsafeRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(rowKey.copy(), newMatchList) + newMatchList + } else { + existingMatchList + } + matchList += unsafeRow.copy() + } + } + + new UnsafeHashedRelation(hashTable) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index db5be9f453674..4443455ef11fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -39,6 +39,9 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output + override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 9eaac817d9268..874712a4e739f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -43,10 +43,10 @@ case class LeftSemiJoinHash( protected override def doExecute(): RDD[InternalRow] = { right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) => if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter, copy = false) + val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = HashedRelation(buildIter, rightKeyGenerator) + val hashRelation = buildHashRelation(buildIter) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60b2a..948d0ccebceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + val hashed = buildHashRelation(buildIter) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index ab0a6ad56acde..f54f1edd38ec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,24 +50,25 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) joinType match { case LeftOuter => - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val keyGenerator = newProjection(leftKeys, left.output) + val hashed = buildHashRelation(rightIter) + val keyGenerator = streamedKeyGenerator() leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey)) }) case RightOuter => - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - val keyGenerator = newProjection(rightKeys, right.output) + val hashed = buildHashRelation(leftIter) + val keyGenerator = streamedKeyGenerator() rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow) }) case FullOuter => + // TODO(davies): use UnsafeRow val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala index 421d510e6782d..29f3beb3cb3c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.catalyst.rules.Rule */ @DeveloperApi case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { + + require(UnsafeProjection.canSupport(child.schema), s"Cannot convert ${child.schema} to Unsafe") + override def output: Seq[Attribute] = child.output override def outputsUnsafeRows: Boolean = true override def canProcessUnsafeRows: Boolean = false @@ -93,11 +96,19 @@ private[sql] object EnsureRowFormats extends Rule[SparkPlan] { } case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, then convert everything - // to unsafe rows - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + // If this operator's children produce both unsafe and safe rows, convert everything to + // unsafe rows if all of the children's schemas are supported by UnsafeRow + if (operator.children.forall(c => UnsafeProjection.canSupport(c.schema))) { + operator.withNewChildren { + operator.children.map { + c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c + } + } + } else { + operator.withNewChildren { + operator.children.map { + c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c + } } } } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 3854dc1b7a3d1..d36e2639376e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++
b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -31,7 +31,7 @@ class UnsafeRowSuite extends SparkFunSuite { test("writeToStream") { val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) val arrayBackedUnsafeRow: UnsafeRow = - UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row) + UnsafeProjection.create(Array[DataType](StringType, StringType, IntegerType)).apply(row) assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) val bytesFromArrayBackedRow: Array[Byte] = { val baos = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9d9858b1c6151..9dd2220f0967e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Projection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.{StructField, StructType, IntegerType} import org.apache.spark.util.collection.CompactBuffer @@ -35,13 +37,13 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) assert(hashed.get(InternalRow(10)) === null) val data2 = CompactBuffer[InternalRow](data(2)) data2 += data(2) - assert(hashed.get(data(2)) == data2) + assert(hashed.get(data(2)) === data2) } test("UniqueKeyHashedRelation") { @@ -49,15 +51,40 @@ class HashedRelationSuite extends SparkFunSuite { val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) - assert(hashed.get(data(0)) == CompactBuffer[InternalRow](data(0))) - assert(hashed.get(data(1)) == CompactBuffer[InternalRow](data(1))) - assert(hashed.get(data(2)) == CompactBuffer[InternalRow](data(2))) + assert(hashed.get(data(0)) === CompactBuffer[InternalRow](data(0))) + assert(hashed.get(data(1)) === CompactBuffer[InternalRow](data(1))) + assert(hashed.get(data(2)) === CompactBuffer[InternalRow](data(2))) assert(hashed.get(InternalRow(10)) === null) val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] - assert(uniqHashed.getValue(data(0)) == data(0)) - assert(uniqHashed.getValue(data(1)) == data(1)) - assert(uniqHashed.getValue(data(2)) == data(2)) - 
assert(uniqHashed.getValue(InternalRow(10)) == null) + assert(uniqHashed.getValue(data(0)) === data(0)) + assert(uniqHashed.getValue(data(1)) === data(1)) + assert(uniqHashed.getValue(data(2)) === data(2)) + assert(uniqHashed.getValue(InternalRow(10)) === null) + } + + test("UnsafeHashedRelation") { + val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val buildKey = Seq(BoundReference(0, IntegerType, false)) + val schema = StructType(StructField("a", IntegerType, true) :: Nil) + val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + assert(hashed.isInstanceOf[UnsafeHashedRelation]) + + val toUnsafeKey = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafeKey(_).copy()).toArray + assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) + + val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) + data2 += unsafeData(2).copy() + assert(hashed.get(unsafeData(2)) === data2) + + val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) + .asInstanceOf[UnsafeHashedRelation] + assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) + assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed2.get(unsafeData(2)) === data2) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 85cd02469adb7..61f483ced3217 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -44,12 +44,16 @@ public int hashInt(int input) { return fmix(h1, 4); } - public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) { + public int hashUnsafeWords(Object base, long offset, int lengthInBytes) { + return hashUnsafeWords(base, offset, lengthInBytes, seed); + } + + public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; int h1 = seed; - for (int offset = 0; offset < lengthInBytes; offset += 4) { - int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset); + for (int i = 0; i < lengthInBytes; i += 4) { + int halfWord = PlatformDependent.UNSAFE.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } From 8486cd853104255b4eb013860bba793eef4e74e7 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 22 Jul 2015 13:06:01 -0700 Subject: [PATCH 009/219] [SPARK-9224] [MLLIB] OnlineLDA Performance Improvements Perform in-place updates, reduce the number of transposes, and vectorize operations in the OnlineLDA implementation.
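The central trick is Breeze's in-place assignment operator: `:=` writes into an already-allocated vector, where a plain `=` rebinds the name to a freshly allocated result. A minimal sketch of the pattern (illustrative, not code from the patch):

    import breeze.linalg.{DenseVector, sum}
    import breeze.numerics.{digamma, exp}

    val gammad = DenseVector(1.0, 2.0, 3.0)           // allocated once per document
    val expElogthetad = DenseVector.zeros[Double](3)  // reused buffer
    // Overwrites the buffer in place rather than allocating a new vector:
    expElogthetad := exp(digamma(gammad) - digamma(sum(gammad)))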
Author: Feynman Liang Closes #7454 from feynmanliang/OnlineLDA-perf-improvements and squashes the following commits: 78b0f5a [Feynman Liang] Make in-place variables vals, fix BLAS error 7f62a55 [Feynman Liang] --amend c62cb1e [Feynman Liang] Outer product for stats, revert Range slicing aead650 [Feynman Liang] Range slice, in-place update, reduce transposes --- .../spark/mllib/clustering/LDAOptimizer.scala | 59 +++++++++---------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 8e5154b902d1d..b960ae6c0708d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,15 +19,15 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kron} -import breeze.numerics.{digamma, exp, abs} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.numerics.{abs, digamma, exp} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer -import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector} import org.apache.spark.rdd.RDD /** @@ -370,7 +370,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { iteration += 1 val k = this.k val vocabSize = this.vocabSize - val Elogbeta = dirichletExpectation(lambda) + val Elogbeta = dirichletExpectation(lambda).t val expElogbeta = exp(Elogbeta) val alpha = this.alpha val gammaShape = this.gammaShape @@ -385,41 +385,36 @@ final class OnlineLDAOptimizer extends LDAOptimizer { case v => throw new IllegalArgumentException("Online LDA does not support vector type " + v.getClass) } + if (!ids.isEmpty) { + + // Initialize the variational distribution q(theta|gamma) for the mini-batch + val gammad: BDV[Double] = + new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K + val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K + + val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + var meanchange = 1D + val ctsVector = new BDV[Double](cts) // ids + + // Iterate between gamma and phi until convergence + while (meanchange > 1e-3) { + val lastgamma = gammad.copy + // K K * ids ids + gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha + expElogthetad := exp(digamma(gammad) - digamma(sum(gammad))) + phinorm := expElogbetad * expElogthetad :+ 1e-100 + meanchange = sum(abs(gammad - lastgamma)) / k + } - // Initialize the variational distribution q(theta|gamma) for the mini-batch - var gammad = new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k).t // 1 * K - var Elogthetad = digamma(gammad) - digamma(sum(gammad)) // 1 * K - var expElogthetad = exp(Elogthetad) // 1 * K - val expElogbetad = expElogbeta(::, ids).toDenseMatrix // K * ids - - var phinorm = expElogthetad * expElogbetad + 1e-100 // 1 * ids - var meanchange = 1D - val ctsVector = new BDV[Double](cts).t // 1 * ids - - // Iterate between gamma and phi until 
convergence - while (meanchange > 1e-3) { - val lastgamma = gammad - // 1*K 1 * ids ids * k - gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + alpha - Elogthetad = digamma(gammad) - digamma(sum(gammad)) - expElogthetad = exp(Elogthetad) - phinorm = expElogthetad * expElogbetad + 1e-100 - meanchange = sum(abs(gammad - lastgamma)) / k - } - - val m1 = expElogthetad.t - val m2 = (ctsVector / phinorm).t.toDenseVector - var i = 0 - while (i < ids.size) { - stat(::, ids(i)) := stat(::, ids(i)) + m1 * m2(i) - i += 1 + stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix } } Iterator(stat) } val statsSum: BDM[Double] = stats.reduce(_ += _) - val batchResult = statsSum :* expElogbeta + val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt) From cf21d05f8b5fae52b118fb8846f43d6fda1aea41 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 22 Jul 2015 13:28:09 -0700 Subject: [PATCH 010/219] [SPARK-4366] [SQL] [Follow-up] Fix SqlParser compiling warning. Author: Yin Huai Closes #7588 from yhuai/SPARK-4366-update1 and squashes the following commits: 25f5f36 [Yin Huai] Fix SqlParser Warning. --- .../main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index c04bd6cd85187..29cfc064da89a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -271,8 +271,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { lexical.normalizeKeyword(udfName) match { case "sum" => SumDistinct(exprs.head) case "count" => CountDistinct(exprs) - case name => UnresolvedFunction(name, exprs, isDistinct = true) - case _ => throw new AnalysisException(s"function $udfName does not support DISTINCT") + case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => From 1aca9c13c144fa336af6afcfa666128bf77c49d4 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 22 Jul 2015 15:07:05 -0700 Subject: [PATCH 011/219] [SPARK-8536] [MLLIB] Generalize OnlineLDAOptimizer to asymmetric document-topic Dirichlet priors Modify `LDA` to take asymmetric document-topic prior distributions and `OnlineLDAOptimizer` to use the asymmetric prior during variational inference. This PR only generalizes `OnlineLDAOptimizer` and the associated `LocalLDAModel`; `EMLDAOptimizer` and `DistributedLDAModel` still only support symmetric `alpha` (checked during `EMLDAOptimizer.initialize`). 
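With the generalized setters introduced below, an asymmetric prior can be supplied directly. A usage sketch based on the new API in this diff (the values are illustrative):

    import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.Vectors

    val asymmetric = new LDA()
      .setK(3)
      .setOptimizer(new OnlineLDAOptimizer)
      .setDocConcentration(Vectors.dense(0.2, 0.5, 1.0)) // one alpha per topic
    // A single value still yields a symmetric prior, replicated to length k:
    val symmetric = new LDA().setK(3).setDocConcentration(0.5)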
Author: Feynman Liang Closes #7575 from feynmanliang/SPARK-8536-LDA-asymmetric-priors and squashes the following commits: af8fbb7 [Feynman Liang] Fix merge errors ef5821d [Feynman Liang] Merge remote-tracking branch 'apache/master' into SPARK-8536-LDA-asymmetric-priors 58f1d7b [Feynman Liang] Fix from review feedback a6dcf70 [Feynman Liang] Change docConcentration interface and move LDAOptimizer validation to initialize, add sad path tests 72038ff [Feynman Liang] Add tests referenced against gensim d4284fa [Feynman Liang] Generalize OnlineLDA to asymmetric priors, no tests --- .../apache/spark/mllib/clustering/LDA.scala | 49 +++++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 27 ++++-- .../spark/mllib/clustering/LDASuite.scala | 82 +++++++++++++++++-- 3 files changed, 126 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index a410547a72fda..ab124e6d77c5e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -23,11 +23,10 @@ import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils - /** * :: Experimental :: * @@ -49,14 +48,15 @@ import org.apache.spark.util.Utils class LDA private ( private var k: Int, private var maxIterations: Int, - private var docConcentration: Double, + private var docConcentration: Vector, private var topicConcentration: Double, private var seed: Long, private var checkpointInterval: Int, private var ldaOptimizer: LDAOptimizer) extends Logging { - def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10, ldaOptimizer = new EMLDAOptimizer) + def this() = this(k = 10, maxIterations = 20, docConcentration = Vectors.dense(-1), + topicConcentration = -1, seed = Utils.random.nextLong(), checkpointInterval = 10, + ldaOptimizer = new EMLDAOptimizer) /** * Number of topics to infer. I.e., the number of soft cluster centers. @@ -77,37 +77,50 @@ class LDA private ( * Concentration parameter (commonly named "alpha") for the prior placed on documents' * distributions over topics ("theta"). * - * This is the parameter to a symmetric Dirichlet distribution. + * This is the parameter to a Dirichlet distribution. */ - def getDocConcentration: Double = this.docConcentration + def getDocConcentration: Vector = this.docConcentration /** * Concentration parameter (commonly named "alpha") for the prior placed on documents' * distributions over topics ("theta"). * - * This is the parameter to a symmetric Dirichlet distribution, where larger values - * mean more smoothing (more regularization). + * This is the parameter to a Dirichlet distribution, where larger values mean more smoothing + * (more regularization). * - * If set to -1, then docConcentration is set automatically. - * (default = -1 = automatic) + * If set to a singleton vector Vector(-1), then docConcentration is set automatically. If set to + * singleton vector Vector(t) where t != -1, then t is replicated to a vector of length k during + * [[LDAOptimizer.initialize()]]. 
Otherwise, the [[docConcentration]] vector must be length k. + * (default = Vector(-1) = automatic) * * Optimizer-specific parameter settings: * - EM - * - Value should be > 1.0 - * - default = (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows - * Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * - Currently only supports symmetric distributions, so all values in the vector should be + * the same. + * - Values should be > 1.0 + * - default = uniformly (50 / k) + 1, where 50/k is common in LDA libraries and +1 follows + * from Asuncion et al. (2009), who recommend a +1 adjustment for EM. * - Online - * - Value should be >= 0 - * - default = (1.0 / k), following the implementation from + * - Values should be >= 0 + * - default = uniformly (1.0 / k), following the implementation from * [[https://github.com/Blei-Lab/onlineldavb]]. */ - def setDocConcentration(docConcentration: Double): this.type = { + def setDocConcentration(docConcentration: Vector): this.type = { this.docConcentration = docConcentration this } + /** Replicates Double to create a symmetric prior */ + def setDocConcentration(docConcentration: Double): this.type = { + this.docConcentration = Vectors.dense(docConcentration) + this + } + /** Alias for [[getDocConcentration]] */ - def getAlpha: Double = getDocConcentration + def getAlpha: Vector = getDocConcentration + + /** Alias for [[setDocConcentration()]] */ + def setAlpha(alpha: Vector): this.type = setDocConcentration(alpha) /** Alias for [[setDocConcentration()]] */ def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index b960ae6c0708d..f4170a3d98dd8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -27,7 +27,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer -import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD /** @@ -95,8 +95,11 @@ final class EMLDAOptimizer extends LDAOptimizer { * Compute bipartite term/doc graph. 
*/ override private[clustering] def initialize(docs: RDD[(Long, Vector)], lda: LDA): LDAOptimizer = { + val docConcentration = lda.getDocConcentration(0) + require({ + lda.getDocConcentration.toArray.forall(_ == docConcentration) + }, "EMLDAOptimizer currently only supports symmetric document-topic priors") - val docConcentration = lda.getDocConcentration val topicConcentration = lda.getTopicConcentration val k = lda.getK @@ -229,10 +232,10 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private var vocabSize: Int = 0 /** alias for docConcentration */ - private var alpha: Double = 0 + private var alpha: Vector = Vectors.dense(0) /** (private[clustering] for debugging) Get docConcentration */ - private[clustering] def getAlpha: Double = alpha + private[clustering] def getAlpha: Vector = alpha /** alias for topicConcentration */ private var eta: Double = 0 @@ -343,7 +346,19 @@ final class OnlineLDAOptimizer extends LDAOptimizer { this.k = lda.getK this.corpusSize = docs.count() this.vocabSize = docs.first()._2.size - this.alpha = if (lda.getDocConcentration == -1) 1.0 / k else lda.getDocConcentration + this.alpha = if (lda.getDocConcentration.size == 1) { + if (lda.getDocConcentration(0) == -1) Vectors.dense(Array.fill(k)(1.0 / k)) + else { + require(lda.getDocConcentration(0) >= 0, s"all entries in alpha must be >=0, got: $alpha") + Vectors.dense(Array.fill(k)(lda.getDocConcentration(0))) + } + } else { + require(lda.getDocConcentration.size == k, s"alpha must have length k, got: $alpha") + lda.getDocConcentration.foreachActive { case (_, x) => + require(x >= 0, s"all entries in alpha must be >= 0, got: $alpha") + } + lda.getDocConcentration + } this.eta = if (lda.getTopicConcentration == -1) 1.0 / k else lda.getTopicConcentration this.randomGenerator = new Random(lda.getSeed) @@ -372,7 +387,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val vocabSize = this.vocabSize val Elogbeta = dirichletExpectation(lambda).t val expElogbeta = exp(Elogbeta) - val alpha = this.alpha + val alpha = this.alpha.toBreeze val gammaShape = this.gammaShape val stats: RDD[BDM[Double]] = batch.mapPartitions { docs => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 721a065658951..da70d9bd7c790 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils @@ -132,22 +132,38 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("setter alias") { val lda = new LDA().setAlpha(2.0).setBeta(3.0) - assert(lda.getAlpha === 2.0) - assert(lda.getDocConcentration === 2.0) + assert(lda.getAlpha.toArray.forall(_ === 2.0)) + assert(lda.getDocConcentration.toArray.forall(_ === 2.0)) assert(lda.getBeta === 3.0) assert(lda.getTopicConcentration === 3.0) } + test("initializing with alpha length != k or 1 fails") { + intercept[IllegalArgumentException] { + val lda = new LDA().setK(2).setAlpha(Vectors.dense(1, 2, 3, 4)) + val corpus = sc.parallelize(tinyCorpus, 
2)
+      lda.run(corpus)
+    }
+  }
+
+  test("initializing with elements in alpha < 0 fails") {
+    intercept[IllegalArgumentException] {
+      val lda = new LDA().setK(4).setAlpha(Vectors.dense(-1, 2, 3, 4))
+      val corpus = sc.parallelize(tinyCorpus, 2)
+      lda.run(corpus)
+    }
+  }
+
   test("OnlineLDAOptimizer initialization") {
     val lda = new LDA().setK(2)
     val corpus = sc.parallelize(tinyCorpus, 2)
     val op = new OnlineLDAOptimizer().initialize(corpus, lda)
     op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567)
-    assert(op.getAlpha == 0.5) // default 1.0 / k
-    assert(op.getEta == 0.5) // default 1.0 / k
-    assert(op.getKappa == 0.9876)
-    assert(op.getMiniBatchFraction == 0.123)
-    assert(op.getTau0 == 567)
+    assert(op.getAlpha.toArray.forall(_ === 0.5)) // default 1.0 / k
+    assert(op.getEta === 0.5) // default 1.0 / k
+    assert(op.getKappa === 0.9876)
+    assert(op.getMiniBatchFraction === 0.123)
+    assert(op.getTau0 === 567)
   }

   test("OnlineLDAOptimizer one iteration") {
@@ -218,6 +234,56 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }

+  test("OnlineLDAOptimizer with asymmetric prior") {
+    def toydata: Array[(Long, Vector)] = Array(
+      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
+      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
+      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
+      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
+      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
+      Vectors.sparse(6, Array(4, 5), Array(1, 1))
+    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+
+    val docs = sc.parallelize(toydata)
+    val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
+      .setGammaShape(1e10)
+    val lda = new LDA().setK(2)
+      .setDocConcentration(Vectors.dense(0.00001, 0.1))
+      .setTopicConcentration(0.01)
+      .setMaxIterations(100)
+      .setOptimizer(op)
+      .setSeed(12345)
+
+    val ldaModel = lda.run(docs)
+    val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
+    val topics = topicIndices.map { case (terms, termWeights) =>
+      terms.zip(termWeights)
+    }
+
+    /* Verify results with Python:
+
+    import numpy as np
+    from gensim import models
+    corpus = [
+     [(0, 1.0), (1, 1.0)],
+     [(1, 1.0), (2, 1.0)],
+     [(0, 1.0), (2, 1.0)],
+     [(3, 1.0), (4, 1.0)],
+     [(3, 1.0), (5, 1.0)],
+     [(4, 1.0), (5, 1.0)]]
+    np.random.seed(10)
+    lda = models.ldamodel.LdaModel(
+        corpus=corpus, alpha=np.array([0.00001, 0.1]), num_topics=2, update_every=0, passes=100)
+    lda.print_topics()
+
+    > ['0.167*0 + 0.167*1 + 0.167*2 + 0.167*3 + 0.167*4 + 0.167*5',
+       '0.167*0 + 0.167*1 + 0.167*2 + 0.167*4 + 0.167*3 + 0.167*5']
+    */
+    topics.foreach { topic =>
+      assert(topic.forall { case (_, p) => p ~= 0.167 absTol 0.05 })
+    }
+  }
+
   test("model save/load") {
     // Test for LocalLDAModel.
     val localModel = new LocalLDAModel(tinyTopics)

From fe26584a1f5b472fb2e87aa7259aec822a619a3b Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Wed, 22 Jul 2015 15:28:09 -0700
Subject: [PATCH 012/219] [SPARK-9244] Increase some memory defaults

There are a few memory limits that people hit often and that we could make
higher, especially now that memory sizes have grown.

- spark.akka.frameSize: This defaults to 10 (MB) but is often hit for map
  output statuses in large shuffles. This memory is not fully allocated
  up-front, so we can just make it larger and still not affect jobs that
  never send a status that large. We increase it to 128.

- spark.executor.memory: This defaults to 512m, which is really small. We
  increase it to 1g.
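Both limits remain configurable; a minimal sketch of overriding the new defaults (the app name and values are illustrative only):

    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setAppName("LargeShuffleJob")          // hypothetical application name
      .set("spark.executor.memory", "4g")     // raise past the new 1g default
      .set("spark.akka.frameSize", "256")     // in MB; raise past the new 128 default
    val sc = new SparkContext(conf)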
Author: Matei Zaharia Closes #7586 from mateiz/configs and squashes the following commits: ce0038a [Matei Zaharia] [SPARK-9244] Increase some memory defaults --- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../apache/spark/ContextCleanerSuite.scala | 4 ++-- .../org/apache/spark/DistributedSuite.scala | 16 +++++++-------- .../scala/org/apache/spark/DriverSuite.scala | 2 +- .../spark/ExternalShuffleServiceSuite.scala | 2 +- .../org/apache/spark/FileServerSuite.scala | 6 +++--- .../apache/spark/JobCancellationSuite.scala | 4 ++-- .../scala/org/apache/spark/ShuffleSuite.scala | 20 +++++++++---------- .../SparkContextSchedulerCreationSuite.scala | 2 +- .../spark/broadcast/BroadcastSuite.scala | 8 ++++---- .../spark/deploy/LogUrlsStandaloneSuite.scala | 4 ++-- .../spark/deploy/SparkSubmitSuite.scala | 4 ++-- .../CoarseGrainedSchedulerBackendSuite.scala | 2 +- .../scheduler/EventLoggingListenerSuite.scala | 2 +- .../spark/scheduler/ReplayListenerSuite.scala | 2 +- .../SparkListenerWithClusterSuite.scala | 2 +- .../KryoSerializerDistributedSuite.scala | 2 +- .../ExternalAppendOnlyMapSuite.scala | 10 +++++----- .../util/collection/ExternalSorterSuite.scala | 14 ++++++------- docs/configuration.md | 16 +++++++-------- .../mllib/util/LocalClusterSparkContext.scala | 2 +- python/pyspark/tests.py | 6 +++--- .../org/apache/spark/repl/ReplSuite.scala | 10 +++++----- .../org/apache/spark/repl/ReplSuite.scala | 8 ++++---- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 4 ++-- 27 files changed, 78 insertions(+), 80 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d00c012d80560..4976e5eb49468 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -471,7 +471,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli .orElse(Option(System.getenv("SPARK_MEM")) .map(warnSparkMem)) .map(Utils.memoryStringToMb) - .getOrElse(512) + .getOrElse(1024) // Convert java options to env vars as a work around // since we can't set env vars directly in sbt. diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index c179833e5b06a..78e7ddc27d1c7 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -128,7 +128,7 @@ private[spark] object AkkaUtils extends Logging { /** Returns the configured max frame size for Akka messages in bytes. */ def maxFrameSizeBytes(conf: SparkConf): Int = { - val frameSizeInMB = conf.getInt("spark.akka.frameSize", 10) + val frameSizeInMB = conf.getInt("spark.akka.frameSize", 128) if (frameSizeInMB > AKKA_MAX_FRAME_SIZE_IN_MB) { throw new IllegalArgumentException( s"spark.akka.frameSize should not be greater than $AKKA_MAX_FRAME_SIZE_IN_MB MB") diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 1b04a3b1cff0e..e948ca33471a4 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1783,7 +1783,7 @@ public void testGuavaOptional() { // Stop the context created in setUp() and start a local-cluster one, to force usage of the // assembly. 
sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,512]", "JavaAPISuite"); + JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); try { JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); JavaRDD> rdd2 = rdd1.map( diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 501fe186bfd7c..26858ef2774fc 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -292,7 +292,7 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") @@ -370,7 +370,7 @@ class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[Sor sc.stop() val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("ContextCleanerSuite") .set("spark.cleaner.referenceTracking.blocking", "true") .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 2300bcff4f118..600c1403b0344 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -29,7 +29,7 @@ class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() { class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { - val clusterUrl = "local-cluster[2,1,512]" + val clusterUrl = "local-cluster[2,1,1024]" test("task throws not serializable exception") { // Ensures that executors do not crash when an exn is not serializable. If executors crash, @@ -40,7 +40,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val numSlaves = 3 val numPartitions = 10 - sc = new SparkContext("local-cluster[%s,1,512]".format(numSlaves), "test") + sc = new SparkContext("local-cluster[%s,1,1024]".format(numSlaves), "test") val data = sc.parallelize(1 to 100, numPartitions). 
map(x => throw new NotSerializableExn(new NotSerializableClass)) intercept[SparkException] { @@ -50,16 +50,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex } test("local-cluster format") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test") + sc = new SparkContext("local-cluster[2 , 1 , 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[2, 1, 512]", "test") + sc = new SparkContext("local-cluster[2, 1, 1024]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() - sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test") + sc = new SparkContext("local-cluster[ 2, 1, 1024 ]", "test") assert(sc.parallelize(1 to 2, 2).count() == 2) resetSparkContext() } @@ -276,7 +276,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex DistributedSuite.amMaster = true // Using more than two nodes so we don't have a symmetric communication pattern and might // cache a partially correct list of peers. - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") for (i <- 1 to 3) { val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) @@ -294,7 +294,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex test("unpersist RDDs") { DistributedSuite.amMaster = true - sc = new SparkContext("local-cluster[3,1,512]", "test") + sc = new SparkContext("local-cluster[3,1,1024]", "test") val data = sc.parallelize(Seq(true, false, false, false), 4) data.persist(StorageLevel.MEMORY_ONLY_2) data.count diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index b2262033ca238..454b7e607a51b 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -29,7 +29,7 @@ class DriverSuite extends SparkFunSuite with Timeouts { ignore("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val masters = Table("master", "local", "local-cluster[2,1,512]") + val masters = Table("master", "local", "local-cluster[2,1,1024]") forAll(masters) { (master: String) => val process = Utils.executeCommand( Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 140012226fdbb..c38d70252add1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -51,7 +51,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { // This test ensures that the external shuffle service is actually in use for the other tests. 
test("using external shuffle service") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 876418aa13029..1255e71af6c0b 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -139,7 +139,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test("Distributing files on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addFile(tmpFile.toString) val testData = Array((1, 1), (1, 1), (2, 1), (3, 5), (2, 2), (3, 0)) val result = sc.parallelize(testData).reduceByKey { @@ -153,7 +153,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => @@ -164,7 +164,7 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } test ("Dynamically adding JARS on a standalone cluster using local: URL") { - sc = new SparkContext("local-cluster[1,1,512]", "test", newConf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", newConf) sc.addJar(tmpJarUrl.replace("file", "local")) val testData = Array((1, 1)) sc.parallelize(testData).foreach { x => diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 340a9e327107e..1168eb0b802f2 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -64,7 +64,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft test("cluster mode, FIFO scheduler") { val conf = new SparkConf().set("spark.scheduler.mode", "FIFO") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. @@ -75,7 +75,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft val conf = new SparkConf().set("spark.scheduler.mode", "FAIR") val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() conf.set("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) testCount() testTake() // Make sure we can still launch tasks. 
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b68102bfb949f..d91b799ecfc08 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -47,7 +47,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } test("shuffle non-zero block size") { - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val NUM_BLOCKS = 3 val a = sc.parallelize(1 to 10, 2) @@ -73,7 +73,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("shuffle serializer") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (x, new NonJavaSerializableClass(x * 2)) @@ -89,7 +89,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -116,7 +116,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("zero sized blocks without kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // 201 partitions (greater than "spark.shuffle.sort.bypassMergeThreshold") from 4 keys val NUM_BLOCKS = 201 @@ -141,7 +141,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("shuffle on mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -154,7 +154,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sorting on mutable pairs") { // This is not in SortingSuite because of the local cluster setup. 
// Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22)) val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2) @@ -168,7 +168,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("cogroup using mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3")) @@ -195,7 +195,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("subtract mutable pairs") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) def p[T1, T2](_1: T1, _2: T2): MutablePair[T1, T2] = MutablePair(_1, _2) val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33)) val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22")) @@ -210,7 +210,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Kryo") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sc = new SparkContext("local-cluster[2,1,512]", "test", myConf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", myConf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) @@ -223,7 +223,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC test("sort with Java non serializable class - Java") { // Use a local cluster with 2 processes to make sure there are both local and remote blocks - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val a = sc.parallelize(1 to 10, 2) val b = a.map { x => (new NonJavaSerializableClass(x), x) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index dba46f101c580..e5a14a69ef05f 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -123,7 +123,7 @@ class SparkContextSchedulerCreationSuite } test("local-cluster") { - createTaskScheduler("local-cluster[3, 14, 512]").backend match { + createTaskScheduler("local-cluster[3, 14, 1024]").backend match { case s: SparkDeploySchedulerBackend => // OK case _ => fail() } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index c054c718075f8..48e74f06f79b1 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -69,7 +69,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = httpConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -97,7 +97,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val conf = torrentConf.clone conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) @@ -125,7 +125,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) val rdd = sc.parallelize(1 to numSlaves) val results = new DummyBroadcastClass(rdd).doSomething() @@ -308,7 +308,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) // Wait until all salves are up _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 10000) _sc diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index ddc92814c0acf..cbd2aee10c0e2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -33,7 +33,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { private val WAIT_TIMEOUT_MILLIS = 10000 test("verify that correct log urls get propagated from workers") { - sc = new SparkContext("local-cluster[2,1,512]", "test") + sc = new SparkContext("local-cluster[2,1,1024]", "test") val listener = new SaveExecutorInfo sc.addSparkListener(listener) @@ -66,7 +66,7 @@ class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { } val conf = new MySparkConf().set( "spark.extraListeners", classOf[SaveExecutorInfo].getName) - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // Trigger a job so that executors get added sc.parallelize(1 to 100, 4).map(_.toString).count() diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 343d28eef8359..aa78bfe30974c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -337,7 +337,7 
@@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -352,7 +352,7 @@ class SparkSubmitSuite val args = Seq( "--class", JarCreationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--packages", Seq(main, dep).mkString(","), "--repositories", repo, "--conf", "spark.ui.enabled=false", diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 34145691153ce..eef6aafa624ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -26,7 +26,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val conf = new SparkConf conf.set("spark.akka.frameSize", "1") conf.set("spark.default.parallelism", "1") - sc = new SparkContext("local-cluster[2 , 1 , 512]", "test", conf) + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) val frameSize = AkkaUtils.maxFrameSizeBytes(sc.conf) val buffer = new SerializableBuffer(java.nio.ByteBuffer.allocate(2 * frameSize)) val larger = sc.parallelize(Seq(buffer)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index f681f21b6205e..5cb2d4225d281 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -180,7 +180,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // into SPARK-6688. 
val conf = getLoggingConf(testDirPath, compressionCodec) .set("spark.hadoop.fs.defaultFS", "unsupported://example.com") - val sc = new SparkContext("local-cluster[2,2,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,2,1024]", "test", conf) assert(sc.eventLogger.isDefined) val eventLogger = sc.eventLogger.get val eventLogPath = eventLogger.logPath diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 4e3defb43a021..103fc19369c97 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -102,7 +102,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { fileSystem.mkdirs(logDirPath) val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName) - val sc = new SparkContext("local-cluster[2,1,512]", "Test replay", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "Test replay", conf) // Run a few jobs sc.parallelize(1 to 100, 1).count() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala index d97fba00976d2..d1e23ed527ff1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerWithClusterSuite.scala @@ -34,7 +34,7 @@ class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext val WAIT_TIMEOUT_MILLIS = 10000 before { - sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite") + sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite") } test("SparkListener sends executor added message") { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 353b97469cd11..935a091f14f9b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -35,7 +35,7 @@ class KryoSerializerDistributedSuite extends SparkFunSuite { val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) - val sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + val sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) val original = Thread.currentThread.getContextClassLoader val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) SparkEnv.get.serializer.setDefaultClassLoader(loader) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 79eba61a87251..9c362f0de7076 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -244,7 +244,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def testSimpleSpilling(codec: Option[String] = None): Unit = { val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new 
SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -292,7 +292,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[String] val collisionPairs = Seq( @@ -341,7 +341,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes @@ -366,7 +366,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] (1 to 100000).foreach { i => map.insert(i, i) } @@ -383,7 +383,7 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val map = createExternalMap[Int] map.insertAll((1 to 100000).iterator.map(i => (i, i))) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 9cefa612f5491..986cd8623d145 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -176,7 +176,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def testSpillingInLocalCluster(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // reduceByKey - should spill ~8 times val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i)) @@ -254,7 +254,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.001") conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager") - sc = new SparkContext("local-cluster[2,1,512]", "test", conf) + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) // reduceByKey - should spill ~4 times per executor val rddA = sc.parallelize(0 until 100000).map(i => 
(i/2, i)) @@ -554,7 +554,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -611,7 +611,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with many hash collisions") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.0001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) @@ -634,7 +634,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with hash collisions using the Int.MaxValue key") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -658,7 +658,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { test("spilling with null keys and values") { val conf = createSparkConf(true, false) conf.set("spark.shuffle.memoryFraction", "0.001") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -695,7 +695,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { def sortWithoutBreakingSortingContracts(conf: SparkConf) { conf.set("spark.shuffle.memoryFraction", "0.01") conf.set("spark.shuffle.manager", "sort") - sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) // Using wrongOrdering to show integer overflow introduced exception. val rand = new Random(100L) diff --git a/docs/configuration.md b/docs/configuration.md index 8a186ee51c1ca..fea259204ae68 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -31,7 +31,6 @@ which can help detect bugs that only exist when we run in a distributed context. val conf = new SparkConf() .setMaster("local[2]") .setAppName("CountingSheep") - .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} @@ -84,7 +83,7 @@ Running `./bin/spark-submit --help` will show the entire list of these options. each line consists of a key and a value separated by whitespace. For example: spark.master spark://5.6.7.8:7077 - spark.executor.memory 512m + spark.executor.memory 4g spark.eventLog.enabled true spark.serializer org.apache.spark.serializer.KryoSerializer @@ -150,10 +149,9 @@ of the most common options to set are: spark.executor.memory - 512m + 1g - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). 
+ Amount of memory to use per executor process (e.g. 2g, 8g). @@ -886,11 +884,11 @@ Apart from these, the following properties are also available, and may be useful spark.akka.frameSize - 10 + 128 - Maximum message size to allow in "control plane" communication (for serialized tasks and task - results), in MB. Increase this if your tasks need to send back large results to the driver - (e.g. using collect() on a large dataset). + Maximum message size to allow in "control plane" communication; generally only applies to map + output size information sent between executors and the driver. Increase this if you are running + jobs with many thousands of map and reduce tasks and see messages about the frame size. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 5e9101cdd3804..525ab68c7921a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -26,7 +26,7 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => override def beforeAll() { val conf = new SparkConf() - .setMaster("local-cluster[2, 1, 512]") + .setMaster("local-cluster[2, 1, 1024]") .setAppName("test-cluster") .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data sc = new SparkContext(conf) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 5be9937cb04b2..8bfed074c9052 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1823,7 +1823,7 @@ def test_module_dependency_on_cluster(self): | return x + 1 """) proc = subprocess.Popen([self.sparkSubmit, "--py-files", zip, "--master", - "local-cluster[1,1,512]", script], + "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) @@ -1857,7 +1857,7 @@ def test_package_dependency_on_cluster(self): self.create_spark_package("a:mylib:0.1") proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories", "file:" + self.programDir, "--master", - "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) + "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) self.assertIn("[2, 3, 4]", out.decode('utf-8')) @@ -1876,7 +1876,7 @@ def test_single_script_on_cluster(self): # this will fail if you have different spark.executor.memory # in conf/spark-defaults.conf proc = subprocess.Popen( - [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], + [self.sparkSubmit, "--master", "local-cluster[1,1,1024]", script], stdout=subprocess.PIPE) out, err = proc.communicate() self.assertEqual(0, proc.returncode) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index f150fec7db945..5674dcd669bee 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -211,7 +211,7 @@ class ReplSuite extends SparkFunSuite { } test("local-cluster mode") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |var v = 7 |def getV() = v @@ -233,7 +233,7 @@ class ReplSuite extends SparkFunSuite { } test("SPARK-1199 two instances of same class don't type 
check.") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |case class Sum(exp: String, exp2: String) |val a = Sum("A", "B") @@ -256,7 +256,7 @@ class ReplSuite extends SparkFunSuite { test("SPARK-2576 importing SQLContext.implicits._") { // We need to use local-cluster to test this case. - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) |import sqlContext.implicits._ @@ -325,9 +325,9 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("Exception", output) assertContains("ret: Array[Foo] = Array(Foo(1),", output) } - + test("collecting objects of class defined in repl - shuffling") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |case class Foo(i: Int) |val list = List((1, Foo(1)), (1, Foo(2))) diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index e1cee97de32bc..bf8997998e00d 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -209,7 +209,7 @@ class ReplSuite extends SparkFunSuite { } test("local-cluster mode") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |var v = 7 |def getV() = v @@ -231,7 +231,7 @@ class ReplSuite extends SparkFunSuite { } test("SPARK-1199 two instances of same class don't type check.") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |case class Sum(exp: String, exp2: String) |val a = Sum("A", "B") @@ -254,7 +254,7 @@ class ReplSuite extends SparkFunSuite { test("SPARK-2576 importing SQLContext.createDataFrame.") { // We need to use local-cluster to test this case. 
- val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) |import sqlContext.implicits._ @@ -314,7 +314,7 @@ class ReplSuite extends SparkFunSuite { } test("collecting objects of class defined in repl - shuffling") { - val output = runInterpreter("local-cluster[1,1,512]", + val output = runInterpreter("local-cluster[1,1,1024]", """ |case class Foo(i: Int) |val list = List((1, Foo(1)), (1, Foo(2))) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index bee2ecbedb244..72b35959a491b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -53,7 +53,7 @@ class HiveSparkSubmitSuite val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), "--name", "SparkSubmitClassLoaderTest", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", "--jars", jarsString, unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") runSparkSubmit(args) @@ -64,7 +64,7 @@ class HiveSparkSubmitSuite val args = Seq( "--class", SparkSQLConfTest.getClass.getName.stripSuffix("$"), "--name", "SparkSQLConfTest", - "--master", "local-cluster[2,1,512]", + "--master", "local-cluster[2,1,1024]", unusedJar.toString) runSparkSubmit(args) } From 798dff7b4baa952c609725b852bcb6a9c9e5a317 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Wed, 22 Jul 2015 15:54:08 -0700 Subject: [PATCH 013/219] [SPARK-8975] [STREAMING] Adds a mechanism to send a new rate from the driver to the block generator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit First step for [SPARK-7398](https://issues.apache.org/jira/browse/SPARK-7398). tdas huitseeker Author: Iulian Dragos Author: François Garillot Closes #7471 from dragos/topic/streaming-bp/dynamic-rate and squashes the following commits: 8941cf9 [Iulian Dragos] Renames and other nitpicks. 162d9e5 [Iulian Dragos] Use Reflection for accessing truly private `executor` method and use the listener bus to know when receivers have registered (`onStart` is called before receivers have registered, leading to flaky behavior). 210f495 [Iulian Dragos] Revert "Added a few tests that measure the receiver’s rate." 0c51959 [Iulian Dragos] Added a few tests that measure the receiver’s rate. 261a051 [Iulian Dragos] - removed field to hold the current rate limit in rate limiter - made rate limit a Long and default to Long.MaxValue (consequence of the above) - removed custom `waitUntil` and replaced it by `eventually` cd1397d [Iulian Dragos] Add a test for the propagation of a new rate limit from driver to receivers. 6369b30 [Iulian Dragos] Merge pull request #15 from huitseeker/SPARK-8975 d15de42 [François Garillot] [SPARK-8975][Streaming] Adds Ratelimiter unit tests w.r.t. 
spark.streaming.receiver.maxRate 4721c7d [François Garillot] [SPARK-8975][Streaming] Add a mechanism to send a new rate from the driver to the block generator --- .../streaming/receiver/RateLimiter.scala | 30 +++++++-- .../spark/streaming/receiver/Receiver.scala | 2 +- .../streaming/receiver/ReceiverMessage.scala | 3 +- .../receiver/ReceiverSupervisor.scala | 3 + .../receiver/ReceiverSupervisorImpl.scala | 6 ++ .../streaming/scheduler/ReceiverTracker.scala | 9 ++- .../streaming/receiver/RateLimiterSuite.scala | 46 ++++++++++++++ .../scheduler/ReceiverTrackerSuite.scala | 62 +++++++++++++++++++ 8 files changed, 153 insertions(+), 8 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 8df542b367d27..f663def4c0511 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -34,12 +34,32 @@ import org.apache.spark.{Logging, SparkConf} */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { - private val desiredRate = conf.getInt("spark.streaming.receiver.maxRate", 0) - private lazy val rateLimiter = GuavaRateLimiter.create(desiredRate) + // treated as an upper limit + private val maxRateLimit = conf.getLong("spark.streaming.receiver.maxRate", Long.MaxValue) + private lazy val rateLimiter = GuavaRateLimiter.create(maxRateLimit.toDouble) def waitToPush() { - if (desiredRate > 0) { - rateLimiter.acquire() - } + rateLimiter.acquire() } + + /** + * Return the current rate limit. If no limit has been set so far, it returns {{{Long.MaxValue}}}. + */ + def getCurrentLimit: Long = + rateLimiter.getRate.toLong + + /** + * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by + * {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that. + * + * @param newRate A new rate in events per second. It has no effect if it's 0 or negative. + */ + private[receiver] def updateRate(newRate: Long): Unit = + if (newRate > 0) { + if (maxRateLimit > 0) { + rateLimiter.setRate(newRate.min(maxRateLimit)) + } else { + rateLimiter.setRate(newRate) + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 5b5a3fe648602..7504fa44d9fae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -271,7 +271,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable } /** Get the attached executor. 
*/ - private def executor = { + private def executor: ReceiverSupervisor = { assert(executor_ != null, "Executor has not been attached to this receiver") executor_ } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala index 7bf3c33319491..1eb55affaa9d0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala @@ -23,4 +23,5 @@ import org.apache.spark.streaming.Time private[streaming] sealed trait ReceiverMessage extends Serializable private[streaming] object StopReceiver extends ReceiverMessage private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage - +private[streaming] case class UpdateRateLimit(elementsPerSecond: Long) + extends ReceiverMessage diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 6467029a277b2..a7c220f426ecf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -59,6 +59,9 @@ private[streaming] abstract class ReceiverSupervisor( /** Time between a receiver is stopped and started again */ private val defaultRestartDelay = conf.getInt("spark.streaming.receiverRestartDelay", 2000) + /** The current maximum rate limit for this receiver. */ + private[streaming] def getCurrentRateLimit: Option[Long] = None + /** Exception associated with the stopping of the receiver */ @volatile protected var stoppingError: Throwable = null diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index f6ba66b3ae036..2f6841ee8879c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -77,6 +77,9 @@ private[streaming] class ReceiverSupervisorImpl( case CleanupOldBlocks(threshTime) => logDebug("Received delete old batch signal") cleanupOldBlocks(threshTime) + case UpdateRateLimit(eps) => + logInfo(s"Received a new rate limit: $eps.") + blockGenerator.updateRate(eps) } }) @@ -98,6 +101,9 @@ private[streaming] class ReceiverSupervisorImpl( } }, streamId, env.conf) + override private[streaming] def getCurrentRateLimit: Option[Long] = + Some(blockGenerator.getCurrentLimit) + /** Push a single record of received data into block generator. 
*/ def pushSingle(data: Any) { blockGenerator.addData(data) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6910d81d9866e..9cc6ffcd12f61 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, - StopReceiver} + StopReceiver, UpdateRateLimit} import org.apache.spark.util.SerializableConfiguration /** @@ -226,6 +226,13 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logError(s"Deregistered receiver for stream $streamId: $messageWithError") } + /** Update a receiver's maximum ingestion rate */ + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { + for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) { + eP.send(UpdateRateLimit(newRate)) + } + } + /** Add new blocks for the given stream */ private def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { receivedBlockTracker.addBlock(receivedBlockInfo) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala new file mode 100644 index 0000000000000..c6330eb3673fb --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/RateLimiterSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.receiver + +import org.apache.spark.SparkConf +import org.apache.spark.SparkFunSuite + +/** Test suite for the RateLimiter used by receivers */ +class RateLimiterSuite extends SparkFunSuite { + + test("rate limiter initializes even without a maxRate set") { + val conf = new SparkConf() + val rateLimiter = new RateLimiter(conf) {} + rateLimiter.updateRate(105) + assert(rateLimiter.getCurrentLimit === 105) + } + + test("rate limiter updates when below maxRate") { + val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "110") + val rateLimiter = new RateLimiter(conf) {} + rateLimiter.updateRate(105) + assert(rateLimiter.getCurrentLimit === 105) + } + + test("rate limiter stays below maxRate despite large updates") { + val conf = new SparkConf().set("spark.streaming.receiver.maxRate", "100") + val rateLimiter = new RateLimiter(conf) {} + rateLimiter.updateRate(105) + assert(rateLimiter.getCurrentLimit === 100) + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index a6e783861dbe6..aadb7231757b8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -17,11 +17,17 @@ package org.apache.spark.streaming.scheduler +import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming._ import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.receiver._ import org.apache.spark.util.Utils +import org.apache.spark.streaming.dstream.InputDStream +import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.ReceiverInputDStream /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { @@ -72,8 +78,64 @@ class ReceiverTrackerSuite extends TestSuiteBase { assert(locations(0).length === 1) assert(locations(3).length === 1) } + + test("Receiver tracker - propagates rate limit") { + object ReceiverStartedWaiter extends StreamingListener { + @volatile + var started = false + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + started = true + } + } + + ssc.addStreamingListener(ReceiverStartedWaiter) + ssc.scheduler.listenerBus.start(ssc.sc) + + val newRateLimit = 100L + val inputDStream = new RateLimitInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + + // We wait until the receiver has registered with the tracker; + // otherwise our rate update is lost. + eventually(timeout(5 seconds)) { + assert(ReceiverStartedWaiter.started) + } + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + // This is an async message, so we need to wait a bit for it to be processed. + eventually(timeout(3 seconds)) { + assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + } + } } /** An input DStream with a hard-coded receiver that gives access to internals for testing.
*/ +private class RateLimitInputDStream(@transient ssc_ : StreamingContext) + extends ReceiverInputDStream[Int](ssc_) { + + override def getReceiver(): DummyReceiver = SingletonDummyReceiver + + def getCurrentRateLimit: Option[Long] = { + invokeExecutorMethod.getCurrentRateLimit + } + + private def invokeExecutorMethod: ReceiverSupervisor = { + val c = classOf[Receiver[_]] + val ex = c.getDeclaredMethod("executor") + ex.setAccessible(true) + ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor] + } +} + +/** + * A Receiver defined as an object so we can read its rate limit. + * + * @note It must be a top-level object; otherwise serialization would create another + * copy on the executor side and we would not be able to read its rate limit. + */ +private object SingletonDummyReceiver extends DummyReceiver + /** * Dummy receiver implementation */ From 430cd7815dc7875edd126af4b90752ba8a380cf2 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Wed, 22 Jul 2015 16:15:44 -0700 Subject: [PATCH 014/219] [SPARK-9180] fix spark-shell to accept --name option This patch fixes [SPARK-9180](https://issues.apache.org/jira/browse/SPARK-9180). Users can now set the app name of spark-shell using `spark-shell --name "whatever"`. Author: Kenichi Maehashi Closes #7512 from kmaehashi/fix-spark-shell-app-name and squashes the following commits: e24991a [Kenichi Maehashi] use setIfMissing instead of setAppName 18aa4ad [Kenichi Maehashi] fix spark-shell to accept --name option --- bin/spark-shell | 4 ++-- bin/spark-shell2.cmd | 2 +- .../src/main/scala/org/apache/spark/repl/SparkILoop.scala | 2 +- .../src/main/scala/org/apache/spark/repl/Main.scala | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bin/spark-shell b/bin/spark-shell index a6dc863d83fc6..00ab7afd118b5 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -47,11 +47,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562).
stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "$@" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main --name "Spark shell" "$@" fi } diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd index 251309d67f860..b9b0f510d7f5d 100644 --- a/bin/spark-shell2.cmd +++ b/bin/spark-shell2.cmd @@ -32,4 +32,4 @@ if "x%SPARK_SUBMIT_OPTS%"=="x" ( set SPARK_SUBMIT_OPTS="%SPARK_SUBMIT_OPTS% -Dscala.usejavacp=true" :run_shell -%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main %* +%SPARK_HOME%\bin\spark-submit2.cmd --class org.apache.spark.repl.Main --name "Spark shell" %* diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 8f7f9074d3f03..8130868fe1487 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -1008,9 +1008,9 @@ class SparkILoop( val jars = SparkILoop.getAddedJars val conf = new SparkConf() .setMaster(getMaster()) - .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", intp.classServerUri) + .setIfMissing("spark.app.name", "Spark shell") if (execUri != null) { conf.set("spark.executor.uri", execUri) } diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala index eed4a379afa60..be31eb2eda546 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala @@ -65,9 +65,9 @@ object Main extends Logging { val jars = getAddedJars val conf = new SparkConf() .setMaster(getMaster) - .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", classServer.uri) + .setIfMissing("spark.app.name", "Spark shell") logInfo("Spark class server started at " + classServer.uri) if (execUri != null) { conf.set("spark.executor.uri", execUri) From 5307c9d3f7a35c0276b72e743e3a62a44d2bd0f5 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 22 Jul 2015 17:22:12 -0700 Subject: [PATCH 015/219] [SPARK-9223] [PYSPARK] [MLLIB] Support model save/load in LDA Since save / load has been merged in LDA, it takes no time to write the wrappers in Python as well. Author: MechCoder Closes #7587 from MechCoder/python_lda_save_load and squashes the following commits: c8e4ea7 [MechCoder] [SPARK-9223] [PySpark] Support model save/load in LDA --- python/pyspark/mllib/clustering.py | 43 +++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8a92f6911c24b..58ad99d46e23b 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -20,6 +20,7 @@ if sys.version > '3': xrange = range + basestring = str from math import exp, log @@ -579,7 +580,7 @@ class LDAModel(JavaModelWrapper): Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. >>> from pyspark.mllib.linalg import Vectors - >>> from numpy.testing import assert_almost_equal + >>> from numpy.testing import assert_almost_equal, assert_equal >>> data = [ ... 
[1, Vectors.dense([0.0, 1.0])], ... [2, SparseVector(2, {0: 1.0})], @@ -591,6 +592,19 @@ class LDAModel(JavaModelWrapper): >>> topics = model.topicsMatrix() >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) >>> assert_almost_equal(topics, topics_expect, 1) + + >>> import os, tempfile + >>> from shutil import rmtree + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = LDAModel.load(sc, path) + >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix()) + >>> sameModel.vocabSize() == model.vocabSize() + True + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def topicsMatrix(self): @@ -601,6 +615,33 @@ def vocabSize(self): """Vocabulary size (number of terms in the vocabulary)""" return self.call("vocabSize") + def save(self, sc, path): + """Save the LDAModel to disk. + + :param sc: SparkContext + :param path: str, path to where the model should be stored. + """ + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + self._java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + """Load the LDAModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + if not isinstance(sc, SparkContext): + raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) + if not isinstance(path, basestring): + raise TypeError("path should be a basestring, got type %s" % type(path)) + java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( + sc._jsc.sc(), path) + return cls(java_model) + class LDA(object): From a721ee52705100dbd7852f80f92cde4375517e48 Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Wed, 22 Jul 2015 17:35:05 -0700 Subject: [PATCH 016/219] [SPARK-8484] [ML] Added TrainValidationSplit for hyper-parameter tuning. - [X] Added TrainValidationSplit for hyper-parameter tuning. It randomly splits the input dataset into train and validation sets and uses the evaluation metric on the validation set to select the best model. It should be similar to CrossValidator, but simpler and less expensive. - [X] Simplified replacement of https://github.com/apache/spark/pull/6996 Author: martinzapletal Closes #7337 from zapletal-martin/SPARK-8484-TrainValidationSplit and squashes the following commits: cafc949 [martinzapletal] Review comments https://github.com/apache/spark/pull/7337. 511b398 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8484-TrainValidationSplit f4fc9c4 [martinzapletal] SPARK-8484 Resolved feedback to https://github.com/apache/spark/pull/7337 00c4f5a [martinzapletal] SPARK-8484. Styling. d699506 [martinzapletal] SPARK-8484. Styling. 93ed2ee [martinzapletal] Styling. 3bc1853 [martinzapletal] SPARK-8484. Styling. 2aa6f43 [martinzapletal] SPARK-8484. Added TrainValidationSplit for hyper-parameter tuning. It randomly splits the input dataset into train and validation and use evaluation metric on the validation set to select the best model. 21662eb [martinzapletal] SPARK-8484. Added TrainValidationSplit for hyper-parameter tuning. It randomly splits the input dataset into train and validation and use evaluation metric on the validation set to select the best model.
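For reference, a minimal usage sketch of the new API, mirroring the included TrainValidationSplitSuite; the `training` DataFrame is a placeholder and the grid values are illustrative only:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LogisticRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.maxIter, Array(10, 50))
  .build()

// Hold out 25% of the rows for validation; trainRatio defaults to 0.75.
val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setTrainRatio(0.75)

// Fits one model per param map on the train split, scores each on the
// validation split, then refits the best param map on the full dataset.
val model = tvs.fit(training)
val metrics = model.validationMetrics  // one metric per param map
```

Note that after selecting the best param map, `fit` retrains on the full input dataset, so `model.bestModel` reflects all of the data, not just the train split.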
--- .../spark/ml/tuning/CrossValidator.scala | 33 +--- .../ml/tuning/TrainValidationSplit.scala | 168 ++++++++++++++++++ .../spark/ml/tuning/ValidatorParams.scala | 60 +++++++ .../ml/tuning/TrainValidationSplitSuite.scala | 139 +++++++++++++++ 4 files changed, 368 insertions(+), 32 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e2444ab65b43b..f979319cc4b58 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -32,38 +32,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends Params { - - /** - * param for the estimator to be cross-validated - * @group param - */ - val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") - - /** @group getParam */ - def getEstimator: Estimator[_] = $(estimator) - - /** - * param for estimator param maps - * @group param - */ - val estimatorParamMaps: Param[Array[ParamMap]] = - new Param(this, "estimatorParamMaps", "param maps for the estimator") - - /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) - - /** - * param for the evaluator used to select hyper-parameters that maximize the cross-validated - * metric - * @group param - */ - val evaluator: Param[Evaluator] = new Param(this, "evaluator", - "evaluator used to select hyper-parameters that maximize the cross-validated metric") - - /** @group getParam */ - def getEvaluator: Evaluator = $(evaluator) - +private[ml] trait CrossValidatorParams extends ValidatorParams { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala new file mode 100644 index 0000000000000..c0edc730b6fd6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType + +/** + * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. + */ +private[ml] trait TrainValidationSplitParams extends ValidatorParams { + /** + * Param for ratio between train and validation data. Must be between 0 and 1. + * Default: 0.75 + * @group param + */ + val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio", + "ratio between training set and validation set (>= 0 && <= 1)", ParamValidators.inRange(0, 1)) + + /** @group getParam */ + def getTrainRatio: Double = $(trainRatio) + + setDefault(trainRatio -> 0.75) +} + +/** + * :: Experimental :: + * Validation for hyper-parameter tuning. + * Randomly splits the input dataset into train and validation sets, + * and uses evaluation metric on the validation set to select the best model. + * Similar to [[CrossValidator]], but only splits the set once. + */ +@Experimental +class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel] + with TrainValidationSplitParams with Logging { + + def this() = this(Identifiable.randomUID("tvs")) + + /** @group setParam */ + def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + + /** @group setParam */ + def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + + /** @group setParam */ + def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + + /** @group setParam */ + def setTrainRatio(value: Double): this.type = set(trainRatio, value) + + override def fit(dataset: DataFrame): TrainValidationSplitModel = { + val schema = dataset.schema + transformSchema(schema, logging = true) + val sqlCtx = dataset.sqlContext + val est = $(estimator) + val eval = $(evaluator) + val epm = $(estimatorParamMaps) + val numModels = epm.length + val metrics = new Array[Double](epm.length) + + val Array(training, validation) = + dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) + val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() + val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() + + // multi-model training + logDebug(s"Train split with multiple sets of parameters.") + val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() + var i = 0 + while (i < numModels) { + // TODO: duplicate evaluator to take extra params from input + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i))) + logDebug(s"Got metric $metric for model trained with ${epm(i)}.") + metrics(i) += metric + i += 1 + } + validationDataset.unpersist() + + logInfo(s"Train validation split metrics: ${metrics.toSeq}") + val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + logInfo(s"Best train validation split metric: $bestMetric.") + val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + $(estimator).transformSchema(schema) + } + + override 
def validateParams(): Unit = { + super.validateParams() + val est = $(estimator) + for (paramMap <- $(estimatorParamMaps)) { + est.copy(paramMap).validateParams() + } + } + + override def copy(extra: ParamMap): TrainValidationSplit = { + val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit] + if (copied.isDefined(estimator)) { + copied.setEstimator(copied.getEstimator.copy(extra)) + } + if (copied.isDefined(evaluator)) { + copied.setEvaluator(copied.getEvaluator.copy(extra)) + } + copied + } +} + +/** + * :: Experimental :: + * Model from train validation split. + * + * @param uid Id. + * @param bestModel Estimator determined best model. + * @param validationMetrics Evaluated validation metrics. + */ +@Experimental +class TrainValidationSplitModel private[ml] ( + override val uid: String, + val bestModel: Model[_], + val validationMetrics: Array[Double]) + extends Model[TrainValidationSplitModel] with TrainValidationSplitParams { + + override def validateParams(): Unit = { + bestModel.validateParams() + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + bestModel.transform(dataset) + } + + override def transformSchema(schema: StructType): StructType = { + bestModel.transformSchema(schema) + } + + override def copy(extra: ParamMap): TrainValidationSplitModel = { + val copied = new TrainValidationSplitModel ( + uid, + bestModel.copy(extra).asInstanceOf[Model[_]], + validationMetrics.clone()) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala new file mode 100644 index 0000000000000..8897ab0825acd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.evaluation.Evaluator +import org.apache.spark.ml.param.{ParamMap, Param, Params} + +/** + * :: DeveloperApi :: + * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. 
+ */ +@DeveloperApi +private[ml] trait ValidatorParams extends Params { + + /** + * param for the estimator to be validated + * @group param + */ + val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + + /** @group getParam */ + def getEstimator: Estimator[_] = $(estimator) + + /** + * param for estimator param maps + * @group param + */ + val estimatorParamMaps: Param[Array[ParamMap]] = + new Param(this, "estimatorParamMaps", "param maps for the estimator") + + /** @group getParam */ + def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps) + + /** + * param for the evaluator used to select hyper-parameters that maximize the validated metric + * @group param + */ + val evaluator: Param[Evaluator] = new Param(this, "evaluator", + "evaluator used to select hyper-parameters that maximize the validated metric") + + /** @group getParam */ + def getEvaluator: Evaluator = $(evaluator) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala new file mode 100644 index 0000000000000..c8e58f216cceb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCol +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.StructType + +class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext { + test("train validation with logistic regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new TrainValidationSplit() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] + assert(cv.getTrainRatio === 0.5) + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.validationMetrics.length === lrParamMaps.length) + } + + test("train validation with linear regression") { + val dataset = sqlContext.createDataFrame( + sc.parallelize(LinearDataGenerator.generateLinearInput( + 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2)) + + val trainer = new LinearRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(trainer.regParam, Array(1000.0, 0.001)) + .addGrid(trainer.maxIter, Array(0, 10)) + .build() + val eval = new RegressionEvaluator() + val cv = new TrainValidationSplit() + .setEstimator(trainer) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + val cvModel = cv.fit(dataset) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent.getRegParam === 0.001) + assert(parent.getMaxIter === 10) + assert(cvModel.validationMetrics.length === lrParamMaps.length) + + eval.setMetricName("r2") + val cvModel2 = cv.fit(dataset) + val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression] + assert(parent2.getRegParam === 0.001) + assert(parent2.getMaxIter === 10) + assert(cvModel2.validationMetrics.length === lrParamMaps.length) + } + + test("validateParams should check estimatorParamMaps") { + import TrainValidationSplitSuite._ + + val est = new MyEstimator("est") + val eval = new MyEvaluator + val paramMaps = new ParamGridBuilder() + .addGrid(est.inputCol, Array("input1", "input2")) + .build() + + val cv = new TrainValidationSplit() + .setEstimator(est) + .setEstimatorParamMaps(paramMaps) + .setEvaluator(eval) + .setTrainRatio(0.5) + cv.validateParams() // This should pass. 
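+ // An empty inputCol violates MyEstimator.validateParams (require(nonEmpty)), so + // validateParams is now expected to throw IllegalArgumentException.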
+ + val invalidParamMaps = paramMaps :+ ParamMap(est.inputCol -> "") + cv.setEstimatorParamMaps(invalidParamMaps) + intercept[IllegalArgumentException] { + cv.validateParams() + } + } +} + +object TrainValidationSplitSuite { + + abstract class MyModel extends Model[MyModel] + + class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { + + override def validateParams(): Unit = require($(inputCol).nonEmpty) + + override def fit(dataset: DataFrame): MyModel = { + throw new UnsupportedOperationException + } + + override def transformSchema(schema: StructType): StructType = { + throw new UnsupportedOperationException + } + + override def copy(extra: ParamMap): MyEstimator = defaultCopy(extra) + } + + class MyEvaluator extends Evaluator { + + override def evaluate(dataset: DataFrame): Double = { + throw new UnsupportedOperationException + } + + override val uid: String = "eval" + + override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) + } +} From d71a13f475df2d05a7db9e25738d1353cbc8cfc7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 22 Jul 2015 21:02:19 -0700 Subject: [PATCH 017/219] [SPARK-9262][build] Treat Scala compiler warnings as errors I've seen a few cases in the past few weeks in which the compiler threw warnings that were caused by legitimate bugs. This patch upgrades warnings to errors, except deprecation warnings. Note that ideally we should be able to mark deprecation warnings as errors as well. However, due to the inability to suppress individual warning messages in the Scala compiler, we cannot do that (since we do need to access deprecated APIs in Hadoop). Most of the work was done by ericl. Author: Reynold Xin Author: Eric Liang Closes #7598 from rxin/warnings and squashes the following commits: beb311b [Reynold Xin] Fixed tests. 542c031 [Reynold Xin] Fixed one more warning. 87c354a [Reynold Xin] Fixed all non-deprecation warnings.
78660ac [Eric Liang] first effort to fix warnings --- .../apache/spark/api/r/RBackendHandler.scala | 1 + .../org/apache/spark/rdd/CoGroupedRDD.scala | 7 ++-- .../org/apache/spark/util/JsonProtocol.scala | 2 ++ .../util/SerializableConfiguration.scala | 2 -- .../spark/util/SerializableJobConf.scala | 2 -- .../stat/test/KolmogorovSmirnovTest.scala | 4 +-- project/SparkBuild.scala | 33 ++++++++++++++++++- .../sql/catalyst/CatalystTypeConverters.scala | 3 +- .../apache/spark/sql/DataFrameWriter.scala | 3 ++ .../sql/execution/datasources/commands.scala | 4 ++- .../spark/sql/hive/orc/OrcFilters.scala | 6 ++-- .../sql/sources/hadoopFsRelationSuites.scala | 6 ++-- 12 files changed, 55 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 9658e9a696ffa..a5de10fe89c42 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,6 +20,7 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.HashMap +import scala.language.existentials import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 658e8c8b89318..130b58882d8ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -94,13 +94,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } override def getDependencies: Seq[Dependency[_]] = { - rdds.map { rdd: RDD[_ <: Product2[K, _]] => + rdds.map { rdd: RDD[_] => if (rdd.partitioner == Some(part)) { logDebug("Adding one-to-one dependency with " + rdd) new OneToOneDependency(rdd) } else { logDebug("Adding shuffle dependency with " + rdd) - new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer) + new ShuffleDependency[K, Any, CoGroupCombiner]( + rdd.asInstanceOf[RDD[_ <: Product2[K, _]]], part, serializer) } } } @@ -133,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // A list of (rdd iterator, dependency number) pairs val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)] for ((dep, depNum) <- dependencies.zipWithIndex) dep match { - case oneToOneDependency: OneToOneDependency[Product2[K, Any]] => + case oneToOneDependency: OneToOneDependency[Product2[K, Any]] @unchecked => val dependencyPartition = split.narrowDeps(depNum).get.split // Read them from the parent val it = oneToOneDependency.rdd.iterator(dependencyPartition, context) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a078f14af52a1..c600319d9ddb4 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -94,6 +94,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) + case blockUpdated: SparkListenerBlockUpdated => + throw new MatchError(blockUpdated) // TODO(ekl) implement this } } diff --git 
a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala index 30bcf1d2f24d5..3354a923273ff 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableConfiguration.scala @@ -20,8 +20,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils - private[spark] class SerializableConfiguration(@transient var value: Configuration) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala index afbcc6efc850c..cadae472b3f85 100644 --- a/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala +++ b/core/src/main/scala/org/apache/spark/util/SerializableJobConf.scala @@ -21,8 +21,6 @@ import java.io.{ObjectInputStream, ObjectOutputStream} import org.apache.hadoop.mapred.JobConf -import org.apache.spark.util.Utils - private[spark] class SerializableJobConf(@transient var value: JobConf) extends Serializable { private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index d89b0059d83f3..2b3ed6df486c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat.test import scala.annotation.varargs import org.apache.commons.math3.distribution.{NormalDistribution, RealDistribution} -import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest +import org.apache.commons.math3.stat.inference.{KolmogorovSmirnovTest => CommonMathKolmogorovSmirnovTest} import org.apache.spark.Logging import org.apache.spark.rdd.RDD @@ -187,7 +187,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging { } private def evalOneSampleP(ksStat: Double, n: Long): KolmogorovSmirnovTestResult = { - val pval = 1 - new KolmogorovSmirnovTest().cdf(ksStat, n.toInt) + val pval = 1 - new CommonMathKolmogorovSmirnovTest().cdf(ksStat, n.toInt) new KolmogorovSmirnovTestResult(pval, ksStat, NullHypothesis.OneSampleTwoSided.toString) } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 12828547d7077..61a05d375d99e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -154,7 +154,38 @@ object SparkBuild extends PomBuild { if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty }, - javacOptions in Compile ++= Seq("-encoding", "UTF-8") + javacOptions in Compile ++= Seq("-encoding", "UTF-8"), + + // Implements -Xfatal-warnings, ignoring deprecation warnings. + // Code snippet taken from https://issues.scala-lang.org/browse/SI-8410. 
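+ // Deprecation problems are logged as warnings and do not fail the build; all other + // reported problems are logged as errors and counted, and a non-zero count aborts the build.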
+ compile in Compile := { + val analysis = (compile in Compile).value + val s = streams.value + + def logProblem(l: (=> String) => Unit, f: File, p: xsbti.Problem) = { + l(f.toString + ":" + p.position.line.fold("")(_ + ":") + " " + p.message) + l(p.position.lineContent) + l("") + } + + var failed = 0 + analysis.infos.allInfos.foreach { case (k, i) => + i.reportedProblems foreach { p => + val deprecation = p.message.contains("is deprecated") + + if (!deprecation) { + failed = failed + 1 + } + + logProblem(if (deprecation) s.log.warn else s.log.error, k, p) + } + } + + if (failed > 0) { + sys.error(s"$failed fatal warnings") + } + analysis + } ) def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 8f63d2120ad0e..ae0ab2f4c63f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -24,6 +24,7 @@ import java.util.{Map => JavaMap} import javax.annotation.Nullable import scala.collection.mutable.HashMap +import scala.language.existentials import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ @@ -401,7 +402,7 @@ object CatalystTypeConverters { case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray - case m: Map[Any, Any] => + case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ee0201a9d4cb2..05da05d7b8050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -197,6 +197,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { // the table. But, insertInto with Overwrite requires the schema of data be the same // the schema of the table. insertInto(tableName) + + case SaveMode.Overwrite => + throw new UnsupportedOperationException("overwrite mode unsupported.") } } else { val cmd = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index 84a0441e145c5..cd2aa7f7433c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -100,7 +100,7 @@ private[sql] case class InsertIntoHadoopFsRelation( val pathExists = fs.exists(qualifiedOutputPath) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => - sys.error(s"path $qualifiedOutputPath already exists.") + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => fs.delete(qualifiedOutputPath, true) true @@ -108,6 +108,8 @@ private[sql] case class InsertIntoHadoopFsRelation( true case (SaveMode.Ignore, exists) => !exists + case (s, exists) => + throw new IllegalStateException(s"unsupported save mode $s ($exists)") } // If we are appending data to an existing dir. 
val isAppend = pathExists && (mode == SaveMode.Append) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 250e73a4dba92..ddd5d24717add 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -41,10 +41,10 @@ private[orc] object OrcFilters extends Logging { private def buildSearchArgument(expression: Filter, builder: Builder): Option[Builder] = { def newBuilder = SearchArgument.FACTORY.newBuilder() - def isSearchableLiteral(value: Any) = value match { + def isSearchableLiteral(value: Any): Boolean = value match { // These are types recognized by the `SearchArgumentImpl.BuilderImpl.boxLiteral()` method. - case _: String | _: Long | _: Double | _: DateWritable | _: HiveDecimal | _: HiveChar | - _: HiveVarchar | _: Byte | _: Short | _: Integer | _: Float => true + case _: String | _: Long | _: Double | _: Byte | _: Short | _: Integer | _: Float => true + case _: DateWritable | _: HiveDecimal | _: HiveChar | _: HiveVarchar => true case _ => false } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 1cef83fd5e990..2a8748d913569 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -134,7 +134,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { test("save()/load() - non-partitioned table - ErrorIfExists") { withTempDir { file => - intercept[RuntimeException] { + intercept[AnalysisException] { testDF.write.format(dataSourceName).mode(SaveMode.ErrorIfExists).save(file.getCanonicalPath) } } @@ -233,7 +233,7 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils { test("save()/load() - partitioned table - ErrorIfExists") { withTempDir { file => - intercept[RuntimeException] { + intercept[AnalysisException] { partitionedTestDF.write .format(dataSourceName) .mode(SaveMode.ErrorIfExists) @@ -696,7 +696,7 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { // This should only complain that the destination directory already exists, rather than that // the file "empty" is not a Parquet file. assert { - intercept[RuntimeException] { + intercept[AnalysisException] { df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) }.getMessage.contains("already exists") } From b217230f2a96c6d5a0554c593bdf1d1374878688 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 22 Jul 2015 21:04:04 -0700 Subject: [PATCH 018/219] [SPARK-9144] Remove DAGScheduler.runLocallyWithinThread and spark.localExecution.enabled Spark has an option called spark.localExecution.enabled; according to the docs: > Enables Spark to run certain jobs, such as first() or take() on the driver, without sending tasks to the cluster. This can make certain jobs execute very quickly, but may require shipping a whole partition of data to the driver. This feature ends up adding quite a bit of complexity to DAGScheduler, especially in the runLocallyWithinThread method, but as far as I know nobody uses this feature (I searched the mailing list and haven't seen any recent mentions of the configuration or stack traces that include the runLocally method).
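For reference, the caller-visible change is mechanical. A sketch of the before and after, assuming an existing `sc: SparkContext` and an `rdd: RDD[Int]` (both names are placeholders):

```scala
// Spark 1.4 and earlier: allowLocal = true could run a one-partition job
// (e.g. the first pass of take()/first()) directly on the driver.
val res = sc.runJob(rdd, (it: Iterator[Int]) => it.take(1).toArray, Seq(0), allowLocal = true)

// Spark 1.5: the flag is gone; the old overloads remain but are deprecated,
// log a warning, and ignore the flag, so the job always goes through the scheduler.
val resNow = sc.runJob(rdd, (it: Iterator[Int]) => it.take(1).toArray, Seq(0))
```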
As a step towards scheduler complexity reduction, I propose that we remove this feature and all code related to it for Spark 1.5. This pull request simply brings #7484 up to date. Author: Josh Rosen Author: Reynold Xin Closes #7585 from rxin/remove-local-exec and squashes the following commits: 84bd10e [Reynold Xin] Python fix. 1d9739a [Reynold Xin] Merge pull request #7484 from JoshRosen/remove-localexecution eec39fa [Josh Rosen] Remove allowLocal(); deprecate user-facing uses of it. b0835dc [Josh Rosen] Remove local execution code in DAGScheduler 8975d96 [Josh Rosen] Remove local execution tests. ffa8c9b [Josh Rosen] Remove documentation for configuration --- .../scala/org/apache/spark/SparkContext.scala | 86 ++++++++++--- .../apache/spark/api/java/JavaRDDLike.scala | 2 +- .../apache/spark/api/python/PythonRDD.scala | 5 +- .../org/apache/spark/executor/Executor.scala | 2 - .../apache/spark/rdd/PairRDDFunctions.scala | 2 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../apache/spark/rdd/ZippedWithIndexRDD.scala | 3 +- .../apache/spark/scheduler/DAGScheduler.scala | 117 +++--------------- .../spark/scheduler/DAGSchedulerEvent.scala | 1 - .../scala/org/apache/spark/rdd/RDDSuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 83 ++----------- .../OutputCommitCoordinatorSuite.scala | 4 +- .../spark/scheduler/SparkListenerSuite.scala | 2 +- docs/configuration.md | 9 -- .../spark/streaming/kafka/KafkaRDD.scala | 3 +- .../apache/spark/mllib/rdd/SlidingRDD.scala | 2 +- python/pyspark/context.py | 3 +- python/pyspark/rdd.py | 4 +- .../spark/sql/execution/SparkPlan.scala | 3 +- 19 files changed, 108 insertions(+), 229 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4976e5eb49468..6a6b94a271cfc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1758,16 +1758,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli /** * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. The allowLocal - * flag specifies whether the scheduler can run the computation on the driver rather than - * shipping it out to the cluster, for short actions like first(). + * handler function. This is the main entry point for all actions in Spark. */ def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit) { + resultHandler: (Int, U) => Unit): Unit = { if (stopped.get()) { throw new IllegalStateException("SparkContext has been shutdown") } @@ -1777,54 +1774,104 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (conf.getBoolean("spark.logLineage", false)) { logInfo("RDD's recursive dependencies:\n" + rdd.toDebugString) } - dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, - resultHandler, localProperties.get) + dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, resultHandler, localProperties.get) progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } /** - * Run a function on a given set of partitions in an RDD and return the results as an array. The - * allowLocal flag specifies whether the scheduler can run the computation on the driver rather - * than shipping it out to the cluster, for short actions like first(). 
+ * Run a function on a given set of partitions in an RDD and return the results as an array. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int]): Array[U] = { + val results = new Array[U](partitions.size) + runJob[T, U](rdd, func, partitions, (index, res) => results(index) = res) + results + } + + /** + * Run a job on a given set of partitions of an RDD, but take a function of type + * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + */ + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: Iterator[T] => U, + partitions: Seq[Int]): Array[U] = { + val cleanedFunc = clean(func) + runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) + } + + + /** + * Run a function on a given set of partitions in an RDD and pass the results to the given + * handler function. This is the main entry point for all actions in Spark. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. + */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") + def runJob[T, U: ClassTag]( + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitions: Seq[Int], + allowLocal: Boolean, + resultHandler: (Int, U) => Unit): Unit = { + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions, resultHandler) + } + + /** + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * The allowLocal flag is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val results = new Array[U](partitions.size) - runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res) - results + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * + * The allowLocal argument is deprecated as of Spark 1.5.0+. */ + @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], allowLocal: Boolean ): Array[U] = { - val cleanedFunc = clean(func) - runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions, allowLocal) + if (allowLocal) { + logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") + } + runJob(rdd, func, partitions) } /** * Run a job on all partitions in an RDD and return the results in an array. */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** * Run a job on all partitions in an RDD and return the results in an array. 
*/ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { - runJob(rdd, func, 0 until rdd.partitions.size, false) + runJob(rdd, func, 0 until rdd.partitions.length) } /** @@ -1835,7 +1882,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) { - runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processPartition, 0 until rdd.partitions.length, resultHandler) } /** @@ -1847,7 +1894,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli resultHandler: (Int, U) => Unit) { val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) - runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) + runJob[T, U](rdd, processFunc, 0 until rdd.partitions.length, resultHandler) } /** @@ -1892,7 +1939,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (context: TaskContext, iter: Iterator[T]) => cleanF(iter), partitions, callSite, - allowLocal = false, resultHandler, localProperties.get) new SimpleFutureAction(waiter, resultFunc) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index c95615a5a9307..829fae1d1d9bf 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -364,7 +364,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { // This is useful for implementing `take` from other language frontends // like Python where the data is serialized. import scala.collection.JavaConversions._ - val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true) + val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds) res.map(x => new java.util.ArrayList(x.toSeq)).toArray } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index dc9f62f39e6d5..598953ac3bcc8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -358,12 +358,11 @@ private[spark] object PythonRDD extends Logging { def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int], - allowLocal: Boolean): Int = { + partitions: JArrayList[Int]): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = - sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) serveIterator(flattenedPartition.iterator, s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 66624ffbe4790..581b40003c6c4 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -215,8 +215,6 @@ private[spark] class Executor( attemptNumber = attemptNumber, metricsSystem = env.metricsSystem) } finally { - // Note: this memory freeing logic is duplicated in 
DAGScheduler.runLocallyWithinThread; - // when changing this, make sure to update both copies. val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 91a6a2d039852..326fafb230a40 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -881,7 +881,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } buf } : Seq[V] - val res = self.context.runJob(self, process, Array(index), false) + val res = self.context.runJob(self, process, Array(index)) res(0) case None => self.filter(_._1 == key).map(_._2).collect() diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9f7ebae3e9af3..394c6686cbabd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -897,7 +897,7 @@ abstract class RDD[T: ClassTag]( */ def toLocalIterator: Iterator[T] = withScope { def collectPartition(p: Int): Array[T] = { - sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head + sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head } (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } @@ -1273,7 +1273,7 @@ abstract class RDD[T: ClassTag]( val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) - val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true) + val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) partsScanned += numPartsToTry diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index 523aaf2b860b5..e277ae28d588f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -50,8 +50,7 @@ class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, L prev.context.runJob( prev, Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - allowLocal = false + 0 until n - 1 // do not need to count the last partition ).scanLeft(0L)(_ + _) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b829d06923404..552dabcfa5139 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,6 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -128,10 +127,6 @@ class DAGScheduler( // This is only safe because DAGScheduler runs in a single thread. private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() - - /** If enabled, we may run certain actions like take() and first() locally. 
*/ - private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) - /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) @@ -515,7 +510,6 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. @@ -535,7 +529,7 @@ class DAGScheduler( val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler) eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, + jobId, rdd, func2, partitions.toArray, callSite, waiter, SerializationUtils.clone(properties))) waiter } @@ -545,11 +539,10 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], callSite: CallSite, - allowLocal: Boolean, resultHandler: (Int, U) => Unit, properties: Properties): Unit = { val start = System.nanoTime - val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) + val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => logInfo("Job %d finished: %s, took %f s".format @@ -576,8 +569,7 @@ class DAGScheduler( val partitions = (0 until rdd.partitions.size).toArray val jobId = nextJobId.getAndIncrement() eventProcessLoop.post(JobSubmitted( - jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, - SerializationUtils.clone(properties))) + jobId, rdd, func2, partitions, callSite, listener, SerializationUtils.clone(properties))) listener.awaitResult() // Will throw an exception if the job fails } @@ -654,74 +646,6 @@ class DAGScheduler( } } - /** - * Run a job on an RDD locally, assuming it has only a single partition and no dependencies. - * We run the operation in a separate thread just in case it takes a bunch of time, so that we - * don't block the DAGScheduler event loop or other concurrent jobs. - */ - protected def runLocally(job: ActiveJob) { - logInfo("Computing the requested partition locally") - new Thread("Local computation of job " + job.jobId) { - override def run() { - runLocallyWithinThread(job) - } - }.start() - } - - // Broken out for easier testing in DAGSchedulerSuite. - protected def runLocallyWithinThread(job: ActiveJob) { - var jobResult: JobResult = JobSucceeded - try { - val rdd = job.finalStage.rdd - val split = rdd.partitions(job.partitions(0)) - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) - val taskContext = - new TaskContextImpl( - job.finalStage.id, - job.partitions(0), - taskAttemptId = 0, - attemptNumber = 0, - taskMemoryManager = taskMemoryManager, - metricsSystem = env.metricsSystem, - runningLocally = true) - TaskContext.setTaskContext(taskContext) - try { - val result = job.func(taskContext, rdd.iterator(split, taskContext)) - job.listener.taskSucceeded(0, result) - } finally { - taskContext.markTaskCompleted() - TaskContext.unset() - // Note: this memory freeing logic is duplicated in Executor.run(); when changing this, - // make sure to update both copies. 
- val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() - if (freedMemory > 0) { - if (sc.getConf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { - throw new SparkException(s"Managed memory leak detected; size = $freedMemory bytes") - } else { - logError(s"Managed memory leak detected; size = $freedMemory bytes") - } - } - } - } catch { - case e: Exception => - val exception = new SparkDriverExecutionException(e) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - case oom: OutOfMemoryError => - val exception = new SparkException("Local job aborted due to out of memory error", oom) - jobResult = JobFailed(exception) - job.listener.jobFailed(exception) - } finally { - val s = job.finalStage - // clean up data structures that were populated for a local job, - // but that won't get cleaned up via the normal paths through - // completion events or stage abort - stageIdToStage -= s.id - jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult)) - } - } - /** Finds the earliest-created active job that needs the stage */ // TODO: Probably should actually find among the active jobs that need this // stage the one with the highest priority (highest-priority pool, earliest created). @@ -784,7 +708,6 @@ class DAGScheduler( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties) { @@ -802,29 +725,20 @@ class DAGScheduler( if (finalStage != null) { val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() - logInfo("Got job %s (%s) with %d output partitions (allowLocal=%s)".format( - job.jobId, callSite.shortForm, partitions.length, allowLocal)) + logInfo("Got job %s (%s) with %d output partitions".format( + job.jobId, callSite.shortForm, partitions.length)) logInfo("Final stage: " + finalStage + "(" + finalStage.name + ")") logInfo("Parents of final stage: " + finalStage.parents) logInfo("Missing parents: " + getMissingParentStages(finalStage)) - val shouldRunLocally = - localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 val jobSubmissionTime = clock.getTimeMillis() - if (shouldRunLocally) { - // Compute very short actions like first() or take() with no parent stages locally. 
- listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties)) - runLocally(job) - } else { - jobIdToActiveJob(jobId) = job - activeJobs += job - finalStage.resultOfJob = Some(job) - val stageIds = jobIdToStageIds(jobId).toArray - val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) - listenerBus.post( - SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) - submitStage(finalStage) - } + jobIdToActiveJob(jobId) = job + activeJobs += job + finalStage.resultOfJob = Some(job) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post( + SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties)) + submitStage(finalStage) } submitWaitingStages() } @@ -1486,9 +1400,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { - case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, - listener, properties) + case JobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) => + dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, callSite, listener, properties) case StageCancelled(stageId) => dagScheduler.handleStageCancellation(stageId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index a927eae2b04be..a213d419cf033 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -40,7 +40,6 @@ private[scheduler] case class JobSubmitted( finalRDD: RDD[_], func: (TaskContext, Iterator[_]) => _, partitions: Array[Int], - allowLocal: Boolean, callSite: CallSite, listener: JobListener, properties: Properties = null) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index f6da9f98ad253..5f718ea9f7be1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -679,7 +679,7 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { test("runJob on an invalid partition") { intercept[IllegalArgumentException] { - sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2), false) + sc.runJob(sc.parallelize(1 to 10, 2), {iter: Iterator[Int] => iter.size}, Seq(0, 1, 2)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3462a82c9cdd3..86dff8fb577d5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -153,9 +153,7 @@ class DAGSchedulerSuite } before { - // Enable local execution for this test - val conf = new SparkConf().set("spark.localExecution.enabled", "true") - sc = new SparkContext("local", "DAGSchedulerSuite", conf) + sc = new SparkContext("local", "DAGSchedulerSuite") sparkListener.submittedStageInfos.clear() sparkListener.successfulStages.clear() sparkListener.failedStages.clear() @@ -172,12 +170,7 @@ class DAGSchedulerSuite sc.listenerBus, 
mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } @@ -241,10 +234,9 @@ class DAGSchedulerSuite rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, - allowLocal: Boolean = false, listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, CallSite("", ""), listener)) + runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener)) jobId } @@ -284,37 +276,6 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } - test("local job") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - Array(42 -> 0).iterator - override def getPartitions: Array[Partition] = - Array( new Partition { override def index: Int = 0 } ) - override def getPreferredLocations(split: Partition): List[String] = Nil - override def toString: String = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results === Map(0 -> 42)) - assertDataStructuresEmpty() - } - - test("local job oom") { - val rdd = new PairOfIntsRDD(sc, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = - throw new java.lang.OutOfMemoryError("test local job oom") - override def getPartitions = Array( new Partition { override def index = 0 } ) - override def getPreferredLocations(split: Partition) = Nil - override def toString = "DAGSchedulerSuite Local RDD" - } - val jobId = scheduler.nextJobId.getAndIncrement() - runEvent( - JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, CallSite("", ""), jobListener)) - assert(results.size == 0) - assertDataStructuresEmpty() - } - test("run trivial job w/ dependency") { val baseRdd = new MyRDD(sc, 1, Nil) val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd))) @@ -452,12 +413,7 @@ class DAGSchedulerSuite sc.listenerBus, mapOutputTracker, blockManagerMaster, - sc.env) { - override def runLocally(job: ActiveJob) { - // don't bother with the thread while unit testing - runLocallyWithinThread(job) - } - } + sc.env) dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) @@ -889,40 +845,23 @@ class DAGSchedulerSuite // Run this on executors sc.parallelize(1 to 10, 2).foreach { item => acc.add(1) } - // Run this within a local thread - sc.parallelize(1 to 10, 2).map { item => acc.add(1) }.take(1) - - // Make sure we can still run local commands as well as cluster commands. 
+ // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("misbehaved resultHandler should not crash DAGScheduler and SparkContext") { - val e1 = intercept[SparkDriverExecutionException] { - val rdd = sc.parallelize(1 to 10, 2) - sc.runJob[Int, Int]( - rdd, - (context: TaskContext, iter: Iterator[Int]) => iter.size, - Seq(0), - allowLocal = true, - (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) - } - assert(e1.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - - val e2 = intercept[SparkDriverExecutionException] { + val e = intercept[SparkDriverExecutionException] { val rdd = sc.parallelize(1 to 10, 2) sc.runJob[Int, Int]( rdd, (context: TaskContext, iter: Iterator[Int]) => iter.size, Seq(0, 1), - allowLocal = false, (part: Int, result: Int) => throw new DAGSchedulerSuiteDummyException) } - assert(e2.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) + assert(e.getCause.isInstanceOf[DAGSchedulerSuiteDummyException]) - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { @@ -935,9 +874,8 @@ class DAGSchedulerSuite rdd.reduceByKey(_ + _, 1).count() } - // Make sure we can still run local commands as well as cluster commands. + // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { @@ -951,9 +889,8 @@ class DAGSchedulerSuite } assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) - // Make sure we can still run local commands as well as cluster commands. 
+ // Make sure we can still run commands assert(sc.parallelize(1 to 10, 2).count() === 10) - assert(sc.parallelize(1 to 10, 2).first() === 1) } test("accumulator not calculated for resubmitted result stage") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index a9036da9cc93d..e5ecd4b7c2610 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -134,14 +134,14 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { test("Only one of two duplicate commit tasks should commit") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).commitSuccessfully _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } test("If commit fails, if task is retried it should not be locked, and will succeed.") { val rdd = sc.parallelize(Seq(1), 1) sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).failFirstCommitAttempt _, - 0 until rdd.partitions.size, allowLocal = false) + 0 until rdd.partitions.size) assert(tempDir.list().size === 1) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 651295b7344c5..730535ece7878 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -188,7 +188,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match sc.addSparkListener(listener) val rdd1 = sc.parallelize(1 to 100, 4) val rdd2 = rdd1.map(_.toString) - sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true) + sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1)) sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) diff --git a/docs/configuration.md b/docs/configuration.md index fea259204ae68..200f3cd212e46 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1048,15 +1048,6 @@ Apart from these, the following properties are also available, and may be useful infinite (all available cores) on Mesos. - - spark.localExecution.enabled - false - - Enables Spark to run certain jobs, such as first() or take() on the driver, without sending - tasks to the cluster. This can make certain jobs execute very quickly, but may require - shipping a whole partition of data to the driver. 
- - spark.locality.wait 3s diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index c5cd2154772ac..1a9d78c0d4f59 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -98,8 +98,7 @@ class KafkaRDD[ val res = context.runJob( this, (tc: TaskContext, it: Iterator[R]) => it.take(parts(tc.partitionId)).toArray, - parts.keys.toArray, - allowLocal = true) + parts.keys.toArray) res.foreach(buf ++= _) buf.toArray } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 35e81fcb3de0d..1facf83d806d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -72,7 +72,7 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int val w1 = windowSize - 1 // Get the first w1 items of each partition, starting from the second partition. val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true) + parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() var i = 0 var partitionIndex = 0 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 43bde5ae41e23..eb5b0bbbdac4b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -913,8 +913,7 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. 
mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions, - allowLocal) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7e788148d981c..fa8e0a0574a62 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1293,7 +1293,7 @@ def takeUpToNumLeft(iterator): taken += 1 p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts)) - res = self.context.runJob(self, takeUpToNumLeft, p, True) + res = self.context.runJob(self, takeUpToNumLeft, p) items += res partsScanned += numPartsToTry @@ -2193,7 +2193,7 @@ def lookup(self, key): values = self.filter(lambda kv: kv[0] == key).values() if self.partitioner is not None: - return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False) + return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)]) return values.collect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b0d56b7bf0b86..50c27def8ea54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -165,8 +165,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) val sc = sqlContext.sparkContext val res = - sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p, - allowLocal = false) + sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(n - buf.size)) partsScanned += numPartsToTry From 2f5cbd860e487e7339e627dd7e2c9baa5116b819 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 22 Jul 2015 21:40:23 -0700 Subject: [PATCH 019/219] [SPARK-8364] [SPARKR] Add crosstab to SparkR DataFrames Add `crosstab` to SparkR DataFrames, which takes two column names and returns a local R data.frame. This is similar to `table` in R. However, `table` in SparkR is used for loading SQL tables as DataFrames. The return type is data.frame instead table for `crosstab` to be compatible with Scala/Python. I couldn't run R tests successfully on my local. Many unit tests failed. So let's try Jenkins. 
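For cross-reference, the R method added below is a thin wrapper over the JVM-side DataFrameStatFunctions.crosstab that Scala and Python already expose. A minimal Scala sketch of the equivalent call (the sqlContext, sample rows, and column names here are illustrative, not part of the patch):

    // Minimal sketch, assuming a live SparkContext `sc`.
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val df = sqlContext.createDataFrame(Seq(
      ("a0", "b0"), ("a1", "b1"), ("a2", "b0"), ("a0", "b1")
    )).toDF("a", "b")

    // Pair-wise frequency table; the first output column is named "a_b".
    val ct = df.stat.crosstab("a", "b")
    ct.show()
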
Author: Xiangrui Meng Closes #7318 from mengxr/SPARK-8364 and squashes the following commits: d75e894 [Xiangrui Meng] fix tests 53f6ddd [Xiangrui Meng] fix tests f1348d6 [Xiangrui Meng] update test 47cb088 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-8364 5621262 [Xiangrui Meng] first version without test --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 28 ++++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 ++++ R/pkg/inst/tests/test_sparkSQL.R | 13 +++++++++++++ 4 files changed, 46 insertions(+) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5834813319bfd..7f7a8a2e4de24 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -26,6 +26,7 @@ exportMethods("arrange", "collect", "columns", "count", + "crosstab", "describe", "distinct", "dropna", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a58433df3c8c1..06dd6b75dff3d 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1554,3 +1554,31 @@ setMethod("fillna", } dataFrame(sdf) }) + +#' crosstab +#' +#' Computes a pair-wise frequency table of the given columns. Also known as a contingency +#' table. The number of distinct values for each column should be less than 1e4. At most 1e6 +#' non-zero pair frequencies will be returned. +#' +#' @param col1 name of the first column. Distinct items will make the first item of each row. +#' @param col2 name of the second column. Distinct items will make the column names of the output. +#' @return a local R data.frame representing the contingency table. The first column of each row +#' will be the distinct values of `col1` and the column names will be the distinct values +#' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no +#' occurrences will have `null` as their counts. +#' +#' @rdname statfunctions +#' @export +#' @examples +#' \dontrun{ +#' df <- jsonFile(sqlCtx, "/path/to/file.json") +#' ct = crosstab(df, "title", "gender") +#' } +setMethod("crosstab", + signature(x = "DataFrame", col1 = "character", col2 = "character"), + function(x, col1, col2) { + statFunctions <- callJMethod(x@sdf, "stat") + sct <- callJMethod(statFunctions, "crosstab", col1, col2) + collect(dataFrame(sct)) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 39b5586f7c90e..836e0175c391f 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -59,6 +59,10 @@ setGeneric("count", function(x) { standardGeneric("count") }) # @export setGeneric("countByValue", function(x) { standardGeneric("countByValue") }) +# @rdname statfunctions +# @export +setGeneric("crosstab", function(x, col1, col2) { standardGeneric("crosstab") }) + # @rdname distinct # @export setGeneric("distinct", function(x, numPartitions = 1) { standardGeneric("distinct") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index a3039d36c9402..62fe48a5d6c7b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -987,6 +987,19 @@ test_that("fillna() on a DataFrame", { expect_identical(expected, actual) }) +test_that("crosstab() on a DataFrame", { + rdd <- lapply(parallelize(sc, 0:3), function(x) { + list(paste0("a", x %% 3), paste0("b", x %% 2)) + }) + df <- toDF(rdd, list("a", "b")) + ct <- crosstab(df, "a", "b") + ordered <- ct[order(ct$a_b),] + row.names(ordered) <- NULL + expected <- data.frame("a_b" = c("a0", "a1", "a2"), "b0" = c(1, 0, 1), "b1" = c(1, 1, 0), + stringsAsFactors = FALSE, row.names = NULL) + expect_identical(expected, ordered) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) 
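For reference, the contingency table that the new test asserts, laid out as crosstab returns it (first column named a_b, one column per distinct value of b):

    a_b  b0  b1
    a0    1   1
    a1    0   1
    a2    1   0

The four generated rows are (a0, b0), (a1, b1), (a2, b0), (a0, b1); since the output row order is not guaranteed, the test sorts on a_b and resets row.names before comparing against `expected`.
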
From 410dd41cf6618b93b6daa6147d17339deeaa49ae Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 22 Jul 2015 23:27:25 -0700 Subject: [PATCH 020/219] [SPARK-9268] [ML] Removed varargs annotation from Params.setDefault taking multiple params Removed varargs annotation from Params.setDefault taking multiple params. Though varargs is technically correct, it often requires that developers do clean assembly, rather than (not clean) assembly, which is a nuisance during development. CC: mengxr Author: Joseph K. Bradley Closes #7604 from jkbradley/params-setdefault-varargs and squashes the following commits: 6016dc6 [Joseph K. Bradley] removed varargs annotation from Params.setDefault taking multiple params --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 5 ++++- .../test/java/org/apache/spark/ml/param/JavaTestParams.java | 3 --- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 824efa5ed4b28..954aa17e26a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -476,11 +476,14 @@ trait Params extends Identifiable with Serializable { /** * Sets default values for a list of params. * + * Note: Java developers should use the single-parameter [[setDefault()]]. + * Annotating this with varargs can cause compilation failures due to a Scala compiler bug. + * See SPARK-9268. + * * @param paramPairs a list of param pairs that specify params and their default values to set * respectively. Make sure that the params are initialized before this method * gets called. */ - @varargs protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 3ae09d39ef500..dc6ce8061f62b 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -96,11 +96,8 @@ private void init() { new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); setDefault(myIntParam(), 1); - setDefault(myIntParam().w(1)); setDefault(myDoubleParam(), 0.5); - setDefault(myIntParam().w(1), myDoubleParam().w(0.5)); setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); - setDefault(myDoubleArrayParam().w(new double[] {1.0, 2.0})); } @Override From 825ab1e4526059a77e3278769797c4d065f48bd3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 22 Jul 2015 23:29:26 -0700 Subject: [PATCH 021/219] [SPARK-7254] [MLLIB] Run PowerIterationClustering directly on graph JIRA: https://issues.apache.org/jira/browse/SPARK-7254 Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #6054 from viirya/pic_on_graph and squashes the following commits: 8b87b81 [Liang-Chi Hsieh] Fix scala style. a22fb8b [Liang-Chi Hsieh] For comment. ef565a0 [Liang-Chi Hsieh] Fix indentation. d249aa1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into pic_on_graph 82d7351 [Liang-Chi Hsieh] Run PowerIterationClustering directly on graph. 
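The new overload takes a GraphX graph whose edge attributes are the pairwise similarities; per the scaladoc in the diff below, similarities must be nonnegative, and supplying either direction of a symmetric pair is sufficient. A minimal usage sketch (the toy similarities and the ambient `sc` are illustrative, not part of the patch; edges are mirrored in both directions, as the new test also does):

    import org.apache.spark.graphx.{Edge, Graph}
    import org.apache.spark.mllib.clustering.PowerIterationClustering

    // Two tight triangles joined by a single weak (0.1) edge.
    val similarities = Seq(
      (0L, 1L, 1.0), (1L, 2L, 1.0), (0L, 2L, 1.0),
      (3L, 4L, 1.0), (4L, 5L, 1.0), (3L, 5L, 1.0),
      (2L, 3L, 0.1))
    // Mirror each pair so both (i, j) and (j, i) are present.
    val edges = sc.parallelize(similarities.flatMap { case (i, j, s) =>
      Seq(Edge(i, j, s), Edge(j, i, s))
    })
    val graph = Graph.fromEdges(edges, defaultValue = 0.0)

    val model = new PowerIterationClustering()
      .setK(2)
      .setMaxIterations(20)
      .run(graph)
    model.assignments.collect().foreach(a => println(s"${a.id} -> ${a.cluster}"))
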
--- .../clustering/PowerIterationClustering.scala | 46 ++++++++++++++++++ .../PowerIterationClusteringSuite.scala | 48 +++++++++++++++++++ 2 files changed, 94 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index e7a243f854e33..407e43a024a2e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -153,6 +153,27 @@ class PowerIterationClustering private[clustering] ( this } + /** + * Run the PIC algorithm on Graph. + * + * @param graph an affinity matrix represented as graph, which is the matrix A in the PIC paper. + * The similarity s,,ij,, represented as the edge between vertices (i, j) must + * be nonnegative. This is a symmetric matrix and hence s,,ij,, = s,,ji,,. For + * any (i, j) with nonzero similarity, there should be either (i, j, s,,ij,,) + * or (j, i, s,,ji,,) in the input. Tuples with i = j are ignored, because we + * assume s,,ij,, = 0.0. + * + * @return a [[PowerIterationClusteringModel]] that contains the clustering result + */ + def run(graph: Graph[Double, Double]): PowerIterationClusteringModel = { + val w = normalize(graph) + val w0 = initMode match { + case "random" => randomInit(w) + case "degree" => initDegreeVector(w) + } + pic(w0) + } + /** * Run the PIC algorithm. * @@ -212,6 +233,31 @@ object PowerIterationClustering extends Logging { @Experimental case class Assignment(id: Long, cluster: Int) + /** + * Normalizes the affinity graph (A) and returns the normalized affinity matrix (W). + */ + private[clustering] + def normalize(graph: Graph[Double, Double]): Graph[Double, Double] = { + val vD = graph.aggregateMessages[Double]( + sendMsg = ctx => { + val i = ctx.srcId + val j = ctx.dstId + val s = ctx.attr + if (s < 0.0) { + throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + } + if (s > 0.0) { + ctx.sendToSrc(s) + } + }, + mergeMsg = _ + _, + TripletFields.EdgeOnly) + GraphImpl.fromExistingRDDs(vD, graph.edges) + .mapTriplets( + e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), + TripletFields.Src) + } + /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 19e65f1b53ab5..189000512155f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -68,6 +68,54 @@ class PowerIterationClusteringSuite extends SparkFunSuite with MLlibTestSparkCon assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) } + test("power iteration clustering on graph") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 
3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + + val edges = similarities.flatMap { case (i, j, s) => + if (i != j) { + Seq(Edge(i, j, s), Edge(j, i, s)) + } else { + None + } + } + val graph = Graph.fromEdges(sc.parallelize(edges, 2), 0.0) + + val model = new PowerIterationClustering() + .setK(2) + .run(graph) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + + val model2 = new PowerIterationClustering() + .setK(2) + .setInitializationMode("degree") + .run(sc.parallelize(similarities, 2)) + val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id + } + assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + test("normalize and powerIter") { /* Test normalize() with the following graph: From 6d0d8b406942edcf9fc97e76fb227ff1eb35ca3a Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 22 Jul 2015 23:44:08 -0700 Subject: [PATCH 022/219] [SPARK-8935] [SQL] Implement code generation for all casts JIRA: https://issues.apache.org/jira/browse/SPARK-8935 Author: Yijie Shen Closes #7365 from yjshen/cast_codegen and squashes the following commits: ef6e8b5 [Yijie Shen] getColumn and setColumn in struct cast, autounboxing in array and map eaece18 [Yijie Shen] remove null case in cast code gen fd7eba4 [Yijie Shen] resolve comments 80378a5 [Yijie Shen] the missing self cast 611d66e [Yijie Shen] Bug fix: NullType & primitive object unboxing 6d5c0fe [Yijie Shen] rebase and add Interval codegen 9424b65 [Yijie Shen] tiny style fix 4a1c801 [Yijie Shen] remove CodeHolder class, use function instead. 3f5df88 [Yijie Shen] CodeHolder for complex dataTypes c286f13 [Yijie Shen] moved all the cast code into class body 4edfd76 [Yijie Shen] [WIP] finished primitive part --- .../spark/sql/catalyst/expressions/Cast.scala | 523 ++++++++++++++++-- .../expressions/DateExpressionsSuite.scala | 36 +- 2 files changed, 508 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 3346d3c9f9e61..e66cd828481bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{Interval, UTF8String} +import scala.collection.mutable + object Cast { @@ -418,51 +420,506 @@ case class Cast(child: Expression, dataType: DataType) protected override def nullSafeEval(input: Any): Any = cast(input) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO: Add support for more data types. 
- (child.dataType, dataType) match { + val eval = child.gen(ctx) + val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) + eval.code + + castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast) + } + + // three function arguments are: child.primitive, result.primitive and result.isNull + // it returns the code snippets to be put in null safe evaluation region + private[this] type CastFunction = (String, String, String) => String + + private[this] def nullSafeCastFunction( + from: DataType, + to: DataType, + ctx: CodeGenContext): CastFunction = to match { + + case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;" + case StringType => castToStringCode(from, ctx) + case BinaryType => castToBinaryCode(from) + case DateType => castToDateCode(from, ctx) + case decimal: DecimalType => castToDecimalCode(from, decimal) + case TimestampType => castToTimestampCode(from, ctx) + case IntervalType => castToIntervalCode(from) + case BooleanType => castToBooleanCode(from) + case ByteType => castToByteCode(from) + case ShortType => castToShortCode(from) + case IntegerType => castToIntCode(from) + case FloatType => castToFloatCode(from) + case LongType => castToLongCode(from) + case DoubleType => castToDoubleCode(from) + + case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx) + case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) + case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + } + + // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's + // Key and Value, Struct's field, we need to name out all the variable names involved in a cast. 
+ private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String, + resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = { + s""" + boolean $resultNull = $childNull; + ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)}; + if (!${childNull}) { + ${cast(childPrim, resultPrim, resultNull)} + } + """ + } + + private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = { + from match { + case BinaryType => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);" + case DateType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));""" + case TimestampType => + (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));""" + case _ => + (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" + } + } + + private[this] def castToBinaryCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.getBytes();" + } + + private[this] def castToDateCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val intOpt = ctx.freshName("intOpt") + (c, evPrim, evNull) => s""" + scala.Option $intOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c); + if ($intOpt.isDefined()) { + $evPrim = ((Integer) $intOpt.get()).intValue(); + } else { + $evNull = true; + } + """ + case TimestampType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);"; + case _ => + (c, evPrim, evNull) => s"$evNull = true;" + } + + private[this] def changePrecision(d: String, decimalType: DecimalType, + evPrim: String, evNull: String): String = { + decimalType match { + case DecimalType.Unlimited => + s"$evPrim = $d;" + case DecimalType.Fixed(precision, scale) => + s""" + if ($d.changePrecision($precision, $scale)) { + $evPrim = $d; + } else { + $evNull = true; + } + """ + } + } - case (BinaryType, StringType) => - defineCodeGen (ctx, ev, c => - s"UTF8String.fromBytes($c)") + private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { + from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + new scala.math.BigDecimal( + new java.math.BigDecimal($c.toString()))); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = null; + if ($c) { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1); + } else { + tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0); + } + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DateType => + // date can't cast to decimal in Hive + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + // Note that we lose precision here. 
+ (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case DecimalType() => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone(); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case LongType => + (c, evPrim, evNull) => + s""" + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set($c); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + """ + case x: NumericType => + // All other numeric types can be represented precisely as Doubles + (c, evPrim, evNull) => + s""" + try { + org.apache.spark.sql.types.Decimal tmpDecimal = + new org.apache.spark.sql.types.Decimal().set( + scala.math.BigDecimal.valueOf((double) $c)); + ${changePrecision("tmpDecimal", target, evPrim, evNull)} + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + } + } - case (DateType, StringType) => - defineCodeGen(ctx, ev, c => - s"""UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + private[this] def castToTimestampCode( + from: DataType, + ctx: CodeGenContext): CastFunction = from match { + case StringType => + val longOpt = ctx.freshName("longOpt") + (c, evPrim, evNull) => + s""" + scala.Option $longOpt = + org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c); + if ($longOpt.isDefined()) { + $evPrim = ((Long) $longOpt.get()).longValue(); + } else { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;" + case _: IntegralType => + (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" + case DateType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMillis($c) * 1000;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = ${decimalToTimestampCode(c)};" + case DoubleType => + (c, evPrim, evNull) => + s""" + if (Double.isNaN($c) || Double.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + case FloatType => + (c, evPrim, evNull) => + s""" + if (Float.isNaN($c) || Float.isInfinite($c)) { + $evNull = true; + } else { + $evPrim = (long)($c * 1000000L); + } + """ + } - case (TimestampType, StringType) => - defineCodeGen(ctx, ev, c => - s"""UTF8String.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + private[this] def castToIntervalCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());" + } + + private[this] def decimalToTimestampCode(d: String): String = + s"($d.toBigDecimal().bigDecimal().multiply(new java.math.BigDecimal(1000000L))).longValue()" + private[this] def longToTimeStampCode(l: String): String = s"$l * 1000L" + private[this] def timestampToIntegerCode(ts: String): String = + s"java.lang.Math.floor((double) $ts / 1000000L)" + private[this] def timestampToDoubleCode(ts: String): String = s"$ts / 1000000.0" + + private[this] def castToBooleanCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + case DateType => + // Hive would return null when cast from date to boolean + 
(c, evPrim, evNull) => s"$evNull = true;" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = !$c.isZero();" + case n: NumericType => + (c, evPrim, evNull) => s"$evPrim = $c != 0;" + } + + private[this] def castToByteCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Byte.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (byte) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toByte();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (byte) $c;" + } - case (_, StringType) => - defineCodeGen(ctx, ev, c => s"UTF8String.fromString(String.valueOf($c))") + private[this] def castToShortCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Short.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (short) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toShort();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (short) $c;" + } - case (StringType, IntervalType) => - defineCodeGen(ctx, ev, c => - s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())") + private[this] def castToIntCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Integer.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (int) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toInt();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (int) $c;" + } - // fallback for DecimalType, this must be before other numeric types - case (_, dt: DecimalType) => - super.genCode(ctx, ev) + private[this] def castToLongCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Long.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (long) ${timestampToIntegerCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toLong();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (long) $c;" + } - case (BooleanType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 
1 : 0)") + private[this] def castToFloatCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Float.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = (float) (${timestampToDoubleCode(c)});" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toFloat();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (float) $c;" + } - case (dt: DecimalType, BooleanType) => - defineCodeGen(ctx, ev, c => s"!$c.isZero()") + private[this] def castToDoubleCode(from: DataType): CastFunction = from match { + case StringType => + (c, evPrim, evNull) => + s""" + try { + $evPrim = Double.valueOf($c.toString()); + } catch (java.lang.NumberFormatException e) { + $evNull = true; + } + """ + case BooleanType => + (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + case DateType => + (c, evPrim, evNull) => s"$evNull = true;" + case TimestampType => + (c, evPrim, evNull) => s"$evPrim = ${timestampToDoubleCode(c)};" + case DecimalType() => + (c, evPrim, evNull) => s"$evPrim = $c.toDouble();" + case x: NumericType => + (c, evPrim, evNull) => s"$evPrim = (double) $c;" + } - case (dt: NumericType, BooleanType) => - defineCodeGen(ctx, ev, c => s"$c != 0") + private[this] def castArrayCode( + from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { + val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) + + val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val fromElementNull = ctx.freshName("feNull") + val fromElementPrim = ctx.freshName("fePrim") + val toElementNull = ctx.freshName("teNull") + val toElementPrim = ctx.freshName("tePrim") + val size = ctx.freshName("n") + val j = ctx.freshName("j") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final int $size = $c.size(); + final $arraySeqClass $result = new $arraySeqClass($size); + for (int $j = 0; $j < $size; $j ++) { + if ($c.apply($j) == null) { + $result.update($j, null); + } else { + boolean $fromElementNull = false; + ${ctx.javaType(from.elementType)} $fromElementPrim = + (${ctx.boxedType(from.elementType)}) $c.apply($j); + ${castCode(ctx, fromElementPrim, + fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} + if ($toElementNull) { + $result.update($j, null); + } else { + $result.update($j, $toElementPrim); + } + } + } + $evPrim = $result; + """ + } - case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { + val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx) + val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx) + + val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName + val fromKeyPrim = ctx.freshName("fkp") + val fromKeyNull = ctx.freshName("fkn") + val fromValuePrim = ctx.freshName("fvp") + val fromValueNull = ctx.freshName("fvn") + val toKeyPrim = ctx.freshName("tkp") + val toKeyNull = ctx.freshName("tkn") + val toValuePrim = ctx.freshName("tvp") + val toValueNull = ctx.freshName("tvn") + val result = ctx.freshName("result") + + (c, evPrim, evNull) => + s""" + final $hashMapClass $result = new $hashMapClass(); + 
scala.collection.Iterator iter = $c.iterator(); + while (iter.hasNext()) { + scala.Tuple2 kv = (scala.Tuple2) iter.next(); + boolean $fromKeyNull = false; + ${ctx.javaType(from.keyType)} $fromKeyPrim = + (${ctx.boxedType(from.keyType)}) kv._1(); + ${castCode(ctx, fromKeyPrim, + fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)} + + boolean $fromValueNull = kv._2() == null; + if ($fromValueNull) { + $result.put($toKeyPrim, null); + } else { + ${ctx.javaType(from.valueType)} $fromValuePrim = + (${ctx.boxedType(from.valueType)}) kv._2(); + ${castCode(ctx, fromValuePrim, + fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)} + if ($toValueNull) { + $result.put($toKeyPrim, null); + } else { + $result.put($toKeyPrim, $toValuePrim); + } + } + } + $evPrim = $result; + """ + } - case (_: NumericType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") + private[this] def castStructCode( + from: StructType, to: StructType, ctx: CodeGenContext): CastFunction = { - case other => - super.genCode(ctx, ev) + val fieldsCasts = from.fields.zip(to.fields).map { + case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx) } + val rowClass = classOf[GenericMutableRow].getName + val result = ctx.freshName("result") + val tmpRow = ctx.freshName("tmpRow") + + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fromFieldPrim = ctx.freshName("ffp") + val fromFieldNull = ctx.freshName("ffn") + val toFieldPrim = ctx.freshName("tfp") + val toFieldNull = ctx.freshName("tfn") + val fromType = ctx.javaType(from.fields(i).dataType) + s""" + boolean $fromFieldNull = $tmpRow.isNullAt($i); + if ($fromFieldNull) { + $result.setNullAt($i); + } else { + $fromType $fromFieldPrim = + ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)}; + ${castCode(ctx, fromFieldPrim, + fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} + if ($toFieldNull) { + $result.setNullAt($i); + } else { + ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)}; + } + } + """ + } + }.mkString("\n") + + (c, evPrim, evNull) => + s""" + final $rowClass $result = new $rowClass(${fieldsCasts.size}); + final InternalRow $tmpRow = $c; + $fieldsEvalCode + $evPrim = $result.copy(); + """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f724bab4d8839..bdba6ce891386 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -39,7 +39,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -51,7 +51,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -63,7 +63,7 @@ class DateExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -75,7 +75,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -87,7 +87,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), sdfDay.format(c.getTime).toInt) } } @@ -96,7 +96,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Year") { checkEvaluation(Year(Literal.create(null, DateType)), null) - checkEvaluation(Year(Cast(Literal(d), DateType)), 2015) + checkEvaluation(Year(Literal(d)), 2015) checkEvaluation(Year(Cast(Literal(sdfDate.format(d)), DateType)), 2015) checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) @@ -106,7 +106,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, m, 28) (0 to 5 * 24).foreach { i => c.add(Calendar.HOUR_OF_DAY, 1) - checkEvaluation(Year(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Year(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.YEAR)) } } @@ -115,7 +115,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Quarter") { checkEvaluation(Quarter(Literal.create(null, DateType)), null) - checkEvaluation(Quarter(Cast(Literal(d), DateType)), 2) + checkEvaluation(Quarter(Literal(d)), 2) checkEvaluation(Quarter(Cast(Literal(sdfDate.format(d)), DateType)), 2) checkEvaluation(Quarter(Cast(Literal(ts), DateType)), 4) @@ -125,7 +125,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, m, 28, 0, 0, 0) (0 to 5 * 24).foreach { i => c.add(Calendar.HOUR_OF_DAY, 1) - checkEvaluation(Quarter(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Quarter(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) / 3 + 1) } } @@ -134,7 +134,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Month") { checkEvaluation(Month(Literal.create(null, DateType)), null) - checkEvaluation(Month(Cast(Literal(d), DateType)), 4) + checkEvaluation(Month(Literal(d)), 4) checkEvaluation(Month(Cast(Literal(sdfDate.format(d)), DateType)), 4) checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) @@ -144,7 +144,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) + 1) } } @@ -156,7 +156,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + 
checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.MONTH) + 1) } } @@ -166,7 +166,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Day / DayOfMonth") { checkEvaluation(DayOfMonth(Cast(Literal("2000-02-29"), DateType)), 29) checkEvaluation(DayOfMonth(Literal.create(null, DateType)), null) - checkEvaluation(DayOfMonth(Cast(Literal(d), DateType)), 8) + checkEvaluation(DayOfMonth(Literal(d)), 8) checkEvaluation(DayOfMonth(Cast(Literal(sdfDate.format(d)), DateType)), 8) checkEvaluation(DayOfMonth(Cast(Literal(ts), DateType)), 8) @@ -175,7 +175,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { c.set(y, 0, 1, 0, 0, 0) (0 to 365).foreach { d => c.add(Calendar.DATE, 1) - checkEvaluation(DayOfMonth(Cast(Literal(new Date(c.getTimeInMillis)), DateType)), + checkEvaluation(DayOfMonth(Literal(new Date(c.getTimeInMillis))), c.get(Calendar.DAY_OF_MONTH)) } } @@ -190,14 +190,14 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val c = Calendar.getInstance() (0 to 60 by 5).foreach { s => c.set(2015, 18, 3, 3, 5, s) - checkEvaluation(Second(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Second(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.SECOND)) } } test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) - checkEvaluation(WeekOfYear(Cast(Literal(d), DateType)), 15) + checkEvaluation(WeekOfYear(Literal(d)), 15) checkEvaluation(WeekOfYear(Cast(Literal(sdfDate.format(d)), DateType)), 15) checkEvaluation(WeekOfYear(Cast(Literal(ts), DateType)), 45) checkEvaluation(WeekOfYear(Cast(Literal("2011-05-06"), DateType)), 18) @@ -223,7 +223,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 60 by 15).foreach { m => (0 to 60 by 15).foreach { s => c.set(2015, 18, 3, h, m, s) - checkEvaluation(Hour(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Hour(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.HOUR_OF_DAY)) } } @@ -240,7 +240,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (0 to 60 by 5).foreach { m => (0 to 60 by 15).foreach { s => c.set(2015, 18, 3, 3, m, s) - checkEvaluation(Minute(Cast(Literal(new Timestamp(c.getTimeInMillis)), TimestampType)), + checkEvaluation(Minute(Literal(new Timestamp(c.getTimeInMillis))), c.get(Calendar.MINUTE)) } } From b983d493b490ca8bafe7eb988b62a250987ae353 Mon Sep 17 00:00:00 2001 From: "Perinkulam I. Ganesh" Date: Thu, 23 Jul 2015 07:46:20 +0100 Subject: [PATCH 023/219] [SPARK-8695] [CORE] [MLLIB] TreeAggregation shouldn't be triggered when it doesn't save wall-clock time. Author: Perinkulam I. Ganesh Closes #7397 from piganesh/SPARK-8695 and squashes the following commits: 041620c [Perinkulam I. Ganesh] [SPARK-8695][CORE][MLlib] TreeAggregation shouldn't be triggered when it doesn't save wall-clock time. 9ad067c [Perinkulam I. Ganesh] [SPARK-8695] [core] [WIP] TreeAggregation shouldn't be triggered for 5 partitions a6fed07 [Perinkulam I. 
Ganesh] [SPARK-8695] [core] [WIP] TreeAggregation shouldn't be triggered for 5 partitions --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 394c6686cbabd..6d61d227382d7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1082,7 +1082,9 @@ abstract class RDD[T: ClassTag]( val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce // the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { + + // Don't trigger TreeAggregation when it doesn't save wall-clock time + while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) { numPartitions /= scale val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { From ac3ae0f2be88e0b53f65342efe5fcbe67b5c2106 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 00:43:26 -0700 Subject: [PATCH 024/219] [SPARK-9266] Prevent "managed memory leak detected" exception from masking original exception When a task fails with an exception and also fails to properly clean up its managed memory, the `spark.unsafe.exceptionOnMemoryLeak` memory leak detection mechanism's exceptions will mask the original exception that caused the task to fail. We should throw the memory leak exception only if no other exception occurred. Author: Josh Rosen Closes #7603 from JoshRosen/SPARK-9266 and squashes the following commits: c268cb5 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-9266 c1f0167 [Josh Rosen] Fix the error masking problem 448eae8 [Josh Rosen] Add regression test --- .../org/apache/spark/executor/Executor.scala | 7 ++++-- .../scala/org/apache/spark/FailureSuite.scala | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 581b40003c6c4..e76664f1bd7b0 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -209,16 +209,19 @@ private[spark] class Executor( // Run the actual task and measure its runtime. 
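To make the RDD.scala change above concrete: with 5 partitions and the default depth of 2, scale is max(ceil(5^(1/2)), 2) = 3. The old integer-division test reads 5 > 3 + 5/3 = 4 and triggers an extra tree level; the new ceiling-based test reads 5 > 3 + ceil(5/3) = 5 and stops, which is exactly the 5-partition case called out in the squashed commits. A standalone Scala re-derivation of both conditions, for illustration:

    // Hedged re-derivation of the treeAggregate stopping condition, outside Spark.
    val numPartitions = 5
    val depth = 2
    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)       // 3
    val oldCondition = numPartitions > scale + numPartitions / scale                     // 5 > 4: true
    val newCondition = numPartitions > scale + math.ceil(numPartitions.toDouble / scale) // 5 > 5: false
    // Integer division under-counted the partitions left after one tree level, so the
    // old loop added a level that could not reduce wall-clock time.
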
taskStart = System.currentTimeMillis() + var threwException = true val (value, accumUpdates) = try { - task.run( + val res = task.run( taskAttemptId = taskId, attemptNumber = attemptNumber, metricsSystem = env.metricsSystem) + threwException = false + res } finally { val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory() if (freedMemory > 0) { val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId" - if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) { + if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false) && !threwException) { throw new SparkException(errMsg) } else { logError(errMsg) diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index b099cd3fb7965..69cb4b44cf7ef 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -141,5 +141,30 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { FailureSuiteState.clear() } + test("managed memory leak error should not mask other failures (SPARK-9266)") { + val conf = new SparkConf().set("spark.unsafe.exceptionOnMemoryLeak", "true") + sc = new SparkContext("local[1,1]", "test", conf) + + // If a task leaks memory but fails due to some other cause, then make sure that the original + // cause is preserved + val thrownDueToTaskFailure = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + throw new Exception("intentional task failure") + iter + }.count() + } + assert(thrownDueToTaskFailure.getMessage.contains("intentional task failure")) + + // If the task succeeded but memory was leaked, then the task should fail due to that leak + val thrownDueToMemoryLeak = intercept[SparkException] { + sc.parallelize(Seq(0)).mapPartitions { iter => + TaskContext.get().taskMemoryManager().allocate(128) + iter + }.count() + } + assert(thrownDueToMemoryLeak.getMessage.contains("memory leak")) + } + // TODO: Need to add tests with shuffle fetch failures. } From fb36397b3ce569d77db26df07ac339731cc07b1c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 23 Jul 2015 01:51:34 -0700 Subject: [PATCH 025/219] Revert "[SPARK-8579] [SQL] support arbitrary object in UnsafeRow" Reverts ObjectPool. As it stands, it has a few problems: 1. ObjectPool doesn't work with spilling and memory accounting. 2. I don't think in the long run the idea of an object pool is what we want to support, since it essentially goes back to unmanaged memory, creates pressure on GC, and makes it hard to account for the total in-memory size. 3. The ObjectPool patch removed the specialized getters for strings and binary, and as a result, actually introduced branches when reading non-primitive data types. If we do want to support arbitrary user-defined types in the future, I think we can just add an object array in UnsafeRow, rather than relying on indirect memory addressing through a pool. We also need to pick execution strategies that are optimized for those, rather than keeping a lot of unserialized JVM objects in memory during aggregation. This is probably the hardest thing I had to revert in Spark, due to recent patches that also change the same part of the code. Would be great to get a careful look. Author: Reynold Xin Closes #7591 from rxin/revert-object-pool and squashes the following commits: 01db0bc [Reynold Xin] Scala style. eda89fc [Reynold Xin] Fixed describe. 
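The threwException flag in the Executor hunk above is a reusable pattern: a secondary failure detected during finally-block cleanup must not replace the primary one. A standalone Scala sketch of the same idea, with hypothetical names and nothing Spark-specific:

    // Only raise the cleanup error when the body itself succeeded.
    def runThenCheckLeaks[T](body: => T)(leakedBytes: => Long): T = {
      var threwException = true
      try {
        val res = body
        threwException = false
        res
      } finally {
        val leaked = leakedBytes
        if (leaked > 0 && !threwException) {
          throw new IllegalStateException(s"managed memory leak: $leaked bytes")
        }
        // If the body already threw, a log line would go here instead, so the
        // original exception keeps propagating.
      }
    }
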
2967118 [Reynold Xin] Fixed accessor for JoinedRow. e3294eb [Reynold Xin] Merge branch 'master' into revert-object-pool 657855f [Reynold Xin] Temp commit. c20f2c8 [Reynold Xin] Style fix. fe37079 [Reynold Xin] Revert "[SPARK-8579] [SQL] support arbitrary object in UnsafeRow" --- project/SparkBuild.scala | 2 +- .../UnsafeFixedWidthAggregationMap.java | 150 ++++++------ .../sql/catalyst/expressions/UnsafeRow.java | 229 ++++++++---------- .../spark/sql/catalyst/util/ObjectPool.java | 78 ------ .../sql/catalyst/util/UniqueObjectPool.java | 59 ----- .../execution/UnsafeExternalRowSorter.java | 16 +- .../sql/catalyst/CatalystTypeConverters.scala | 3 +- .../spark/sql/catalyst/InternalRow.scala | 9 +- .../catalyst/expressions/BoundAttribute.scala | 2 + .../sql/catalyst/expressions/Projection.scala | 53 ++++ .../expressions/UnsafeRowConverter.scala | 42 ++-- .../expressions/codegen/CodeGenerator.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../plans/logical/LocalRelation.scala | 7 + .../UnsafeFixedWidthAggregationMapSuite.scala | 65 ++--- .../expressions/UnsafeRowConverterSuite.scala | 137 +++-------- .../sql/catalyst/util/ObjectPoolSuite.scala | 57 ----- .../org/apache/spark/sql/DataFrame.scala | 13 +- .../sql/execution/GeneratedAggregate.scala | 17 +- .../spark/sql/execution/LocalTableScan.scala | 2 - .../sql/execution/UnsafeRowSerializer.scala | 8 +- .../org/apache/spark/sql/UnsafeRowSuite.scala | 3 +- .../execution/UnsafeExternalSortSuite.scala | 7 +- .../execution/UnsafeRowSerializerSuite.scala | 14 +- 24 files changed, 355 insertions(+), 631 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 61a05d375d99e..b5b0adf630b9e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -543,7 +543,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", - javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", + //javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 79d55b36dab01..2f7e84a7f59e2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -19,11 +19,9 @@ import java.util.Iterator; -import scala.Function1; - import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.util.ObjectPool; -import org.apache.spark.sql.catalyst.util.UniqueObjectPool; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import 
org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -40,48 +38,26 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final byte[] emptyBuffer; + private final byte[] emptyAggregationBuffer; - /** - * An empty row used by `initProjection` - */ - private static final InternalRow emptyRow = new GenericInternalRow(); + private final StructType aggregationBufferSchema; - /** - * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. - */ - private final boolean reuseEmptyBuffer; + private final StructType groupingKeySchema; /** - * The projection used to initialize the emptyBuffer + * Encodes grouping keys as UnsafeRows. */ - private final Function1 initProjection; - - /** - * Encodes grouping keys or buffers as UnsafeRows. - */ - private final UnsafeRowConverter keyConverter; - private final UnsafeRowConverter bufferConverter; + private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; /** * A hashmap which maps from opaque bytearray keys to bytearray values. */ private final BytesToBytesMap map; - /** - * An object pool for objects that are used in grouping keys. - */ - private final UniqueObjectPool keyPool; - - /** - * An object pool for objects that are used in aggregation buffers. - */ - private final ObjectPool bufferPool; - /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentBuffer = new UnsafeRow(); + private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); /** * Scratch space that is used when encoding grouping keys into UnsafeRow format. @@ -93,41 +69,69 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; + /** + * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, + * false otherwise. + */ + public static boolean supportsGroupKeySchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + + /** + * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given + * schema, false otherwise. + */ + public static boolean supportsAggregationBufferSchema(StructType schema) { + for (StructField field: schema.fields()) { + if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + return false; + } + } + return true; + } + /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param initProjection the default value for new keys (a "zero" of the agg. function) - * @param keyConverter the converter of the grouping key, used for row conversion. - * @param bufferConverter the converter of the aggregation buffer, used for row conversion. + * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) + * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. + * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - Function1 initProjection, - UnsafeRowConverter keyConverter, - UnsafeRowConverter bufferConverter, + InternalRow emptyAggregationBuffer, + StructType aggregationBufferSchema, + StructType groupingKeySchema, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.initProjection = initProjection; - this.keyConverter = keyConverter; - this.bufferConverter = bufferConverter; - this.enablePerfMetrics = enablePerfMetrics; - + this.emptyAggregationBuffer = + convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); + this.aggregationBufferSchema = aggregationBufferSchema; + this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); + this.groupingKeySchema = groupingKeySchema; this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); - this.keyPool = new UniqueObjectPool(100); - this.bufferPool = new ObjectPool(initialCapacity); + this.enablePerfMetrics = enablePerfMetrics; + } - InternalRow initRow = initProjection.apply(emptyRow); - int emptyBufferSize = bufferConverter.getSizeRequirement(initRow); - this.emptyBuffer = new byte[emptyBufferSize]; - int writtenLength = bufferConverter.writeRow( - initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyBufferSize, - bufferPool); - assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; - // re-use the empty buffer only when there is no object saved in pool. - reuseEmptyBuffer = bufferPool.size() == 0; + /** + * Convert a Java object row into an UnsafeRow, allocating it into a new byte array. + */ + private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) { + final UnsafeRowConverter converter = new UnsafeRowConverter(schema); + final int size = converter.getSizeRequirement(javaRow); + final byte[] unsafeRow = new byte[size]; + final int writtenLength = + converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET, size); + assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; + return unsafeRow; } /** @@ -135,17 +139,16 @@ public UnsafeFixedWidthAggregationMap( * return the same object. */ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); + final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. 
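getAggregationBuffer, continued below, follows a copy-on-first-touch scheme: the first probe for a grouping key copies the pre-encoded empty buffer into the map, and every later probe returns a pointer into the map's own memory so aggregates are updated in place. The same contract on the heap, as a rough sketch with hypothetical names and none of the off-heap details:

    // Heap-only analogue of the aggregation map's buffer lifecycle.
    class SimpleAggMap[K](emptyBuffer: Array[Byte]) {
      private val buffers = scala.collection.mutable.HashMap.empty[K, Array[Byte]]
      // First call for a key inserts a *copy* of the empty buffer; later calls
      // return the same mutable array, which callers mutate in place.
      def getAggregationBuffer(key: K): Array[Byte] =
        buffers.getOrElseUpdate(key, emptyBuffer.clone())
    }
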
If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { groupingKeyConversionScratchSpace = new byte[groupingKeySize]; } - final int actualGroupingKeySize = keyConverter.writeRow( + final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( groupingKey, groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize, - keyPool); + groupingKeySize); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; // Probe our map using the serialized key @@ -156,32 +159,25 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: - if (!reuseEmptyBuffer) { - // There is some objects referenced by emptyBuffer, so generate a new one - InternalRow initRow = initProjection.apply(emptyRow); - bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize, bufferPool); - } loc.putNewKey( groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, - emptyBuffer, + emptyAggregationBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - emptyBuffer.length + emptyAggregationBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentBuffer.pointTo( + currentAggregationBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - bufferConverter.numFields(), - loc.getValueLength(), - bufferPool + aggregationBufferSchema.length(), + loc.getValueLength() ); - return currentBuffer; + return currentAggregationBuffer; } /** @@ -217,16 +213,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - keyConverter.numFields(), - loc.getKeyLength(), - keyPool + groupingKeySchema.length(), + loc.getKeyLength() ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - bufferConverter.numFields(), - loc.getValueLength(), - bufferPool + aggregationBufferSchema.length(), + loc.getValueLength() ); return entry; } @@ -254,8 +248,6 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); - System.out.println("Number of unique objects in keys: " + keyPool.size()); - System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 7f08bf7b742dc..fa1216b455a9e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -19,14 +19,19 @@ import java.io.IOException; import java.io.OutputStream; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; -import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.types.DataType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import 
org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.types.UTF8String; +import static org.apache.spark.sql.types.DataTypes.*; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -40,20 +45,7 @@ * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the * base address of the row) that points to the beginning of the variable-length field, and length - * (they are combined into a long). For other objects, they are stored in a pool, the indexes of - * them are hold in the the word. - * - * In order to support fast hashing and equality checks for UnsafeRows that contain objects - * when used as grouping key in BytesToBytesMap, we put the objects in an UniqueObjectPool to make - * sure all the key have the same index for same object, then we can hash/compare the objects by - * hash/compare the index. - * - * For non-primitive types, the word of a field could be: - * UNION { - * [1] [offset: 31bits] [length: 31bits] // StringType - * [0] [offset: 31bits] [length: 31bits] // BinaryType - * - [index: 63bits] // StringType, Binary, index to object in pool - * } + * (they are combined into a long). * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ @@ -62,13 +54,9 @@ public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; - /** A pool to hold non-primitive objects */ - private ObjectPool pool; - public Object getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } public int getSizeInBytes() { return sizeInBytes; } - public ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; @@ -89,7 +77,42 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } - public static final long OFFSET_BITS = 31L; + /** + * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) + */ + public static final Set settableFieldTypes; + + /** + * Field types that can be read (but not set; set() will throw UnsupportedOperationException). + */ + public static final Set readableFieldTypes; + + // TODO: support DecimalType + static { + settableFieldTypes = Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList(new DataType[] { + NullType, + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DateType, + TimestampType + }))); + + // We support get() on a superset of the types for which we support set(): + final Set _readableFieldTypes = new HashSet<>( + Arrays.asList(new DataType[]{ + StringType, + BinaryType + })); + _readableFieldTypes.addAll(settableFieldTypes); + readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); + } /** * Construct a new UnsafeRow. 
The resulting row won't be usable until `pointTo()` has been called, @@ -104,17 +127,14 @@ public UnsafeRow() { } * @param baseOffset the offset within the base object * @param numFields the number of fields in this row * @param sizeInBytes the size of this row's backing data, in bytes - * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, long baseOffset, int numFields, int sizeInBytes, ObjectPool pool) { + public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { assert numFields >= 0 : "numFields should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; this.sizeInBytes = sizeInBytes; - this.pool = pool; } private void assertIndexIsValid(int index) { @@ -137,68 +157,9 @@ private void setNotNullAt(int i) { BitSetMethods.unset(baseObject, baseOffset, i); } - /** - * Updates the column `i` as Object `value`, which cannot be primitive types. - */ @Override - public void update(int i, Object value) { - if (value == null) { - if (!isNullAt(i)) { - // remove the old value from pool - long idx = getLong(i); - if (idx <= 0) { - // this is the index of old value in pool, remove it - pool.replace((int)-idx, null); - } else { - // there will be some garbage left (UTF8String or byte[]) - } - setNullAt(i); - } - return; - } - - if (isNullAt(i)) { - // there is not an old value, put the new value into pool - int idx = pool.put(value); - setLong(i, (long)-idx); - } else { - // there is an old value, check the type, then replace it or update it - long v = getLong(i); - if (v <= 0) { - // it's the index in the pool, replace old value with new one - int idx = (int)-v; - pool.replace(idx, value); - } else { - // old value is UTF8String or byte[], try to reuse the space - boolean isString; - byte[] newBytes; - if (value instanceof UTF8String) { - newBytes = ((UTF8String) value).getBytes(); - isString = true; - } else { - newBytes = (byte[]) value; - isString = false; - } - int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); - int oldLength = (int) (v & Integer.MAX_VALUE); - if (newBytes.length <= oldLength) { - // the new value can fit in the old buffer, re-use it - PlatformDependent.copyMemory( - newBytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - baseObject, - baseOffset + offset, - newBytes.length); - long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; - setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); - } else { - // Cannot fit in the buffer - int idx = pool.put(value); - setLong(i, (long) -idx); - } - } - } - setNotNullAt(i); + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); } @Override @@ -256,40 +217,14 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } - /** - * Returns the object for column `i`, which should not be primitive type. - */ + @Override + public int size() { + return numFields; + } + @Override public Object get(int i) { - assertIndexIsValid(i); - if (isNullAt(i)) { - return null; - } - long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); - if (v <= 0) { - // It's an index to object in the pool. 
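With the pool removed, the word for a variable-length field reverts to a plain pair: offset in the high 32 bits, byte length in the low 32, decoded by the restored getBinary a little further on. A quick round-trip of that packing in standalone Scala, as an illustration:

    // Hedged sketch of the 32/32 offset-and-size word used for strings and binary.
    def pack(offset: Int, size: Int): Long = (offset.toLong << 32) | size.toLong
    def offsetOf(word: Long): Int = (word >> 32).toInt
    def sizeOf(word: Long): Int = (word & ((1L << 32) - 1)).toInt

    val word = pack(offset = 48, size = 11)
    assert(offsetOf(word) == 48 && sizeOf(word) == 11)
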
- int idx = (int)-v; - return pool.get(idx); - } else { - // The column could be StingType or BinaryType - boolean isString = (v >> (OFFSET_BITS * 2)) > 0; - int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); - int size = (int) (v & Integer.MAX_VALUE); - final byte[] bytes = new byte[size]; - // TODO(davies): Avoid the copy once we can manage the life cycle of Row well. - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size - ); - if (isString) { - return UTF8String.fromBytes(bytes); - } else { - return bytes; - } - } + throw new UnsupportedOperationException(); } @Override @@ -348,6 +283,38 @@ public double getDouble(int i) { } } + @Override + public UTF8String getUTF8String(int i) { + assertIndexIsValid(i); + return isNullAt(i) ? null : UTF8String.fromBytes(getBinary(i)); + } + + @Override + public byte[] getBinary(int i) { + if (isNullAt(i)) { + return null; + } else { + assertIndexIsValid(i); + final long offsetAndSize = getLong(i); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + return bytes; + } + } + + @Override + public String getString(int i) { + return getUTF8String(i).toString(); + } + /** * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal * byte array rather than referencing data stored in a data page. @@ -356,23 +323,17 @@ public double getDouble(int i) { */ @Override public UnsafeRow copy() { - if (pool != null) { - throw new UnsupportedOperationException( - "Copy is not supported for UnsafeRows that use object pools"); - } else { - UnsafeRow rowCopy = new UnsafeRow(); - final byte[] rowDataCopy = new byte[sizeInBytes]; - PlatformDependent.copyMemory( - baseObject, - baseOffset, - rowDataCopy, - PlatformDependent.BYTE_ARRAY_OFFSET, - sizeInBytes - ); - rowCopy.pointTo( - rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes, null); - return rowCopy; - } + UnsafeRow rowCopy = new UnsafeRow(); + final byte[] rowDataCopy = new byte[sizeInBytes]; + PlatformDependent.copyMemory( + baseObject, + baseOffset, + rowDataCopy, + PlatformDependent.BYTE_ARRAY_OFFSET, + sizeInBytes + ); + rowCopy.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + return rowCopy; } /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java deleted file mode 100644 index 97f89a7d0b758..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util; - -/** - * A object pool stores a collection of objects in array, then they can be referenced by the - * pool plus an index. - */ -public class ObjectPool { - - /** - * An array to hold objects, which will grow as needed. - */ - private Object[] objects; - - /** - * How many objects in the pool. - */ - private int numObj; - - public ObjectPool(int capacity) { - objects = new Object[capacity]; - numObj = 0; - } - - /** - * Returns how many objects in the pool. - */ - public int size() { - return numObj; - } - - /** - * Returns the object at position `idx` in the array. - */ - public Object get(int idx) { - assert (idx < numObj); - return objects[idx]; - } - - /** - * Puts an object `obj` at the end of array, returns the index of it. - *

- * The array will grow as needed. - */ - public int put(Object obj) { - if (numObj >= objects.length) { - Object[] tmp = new Object[objects.length * 2]; - System.arraycopy(objects, 0, tmp, 0, objects.length); - objects = tmp; - } - objects[numObj++] = obj; - return numObj - 1; - } - - /** - * Replaces the object at `idx` with new one `obj`. - */ - public void replace(int idx, Object obj) { - assert (idx < numObj); - objects[idx] = obj; - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java deleted file mode 100644 index d512392dcaacc..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util; - -import java.util.HashMap; - -/** - * An unique object pool stores a collection of unique objects in it. - */ -public class UniqueObjectPool extends ObjectPool { - - /** - * A hash map from objects to their indexes in the array. - */ - private HashMap objIndex; - - public UniqueObjectPool(int capacity) { - super(capacity); - objIndex = new HashMap(); - } - - /** - * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will - * return the index of the existing one. - */ - @Override - public int put(Object obj) { - if (objIndex.containsKey(obj)) { - return objIndex.get(obj); - } else { - int idx = super.put(obj); - objIndex.put(obj, idx); - return idx; - } - } - - /** - * The objects can not be replaced. 
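The deleted UniqueObjectPool interned each distinct object once, so equal grouping-key objects always mapped to the same index and the map could hash and compare indexes instead of objects. Reduced to a heap-only Scala sketch of the idea being removed, with hypothetical names:

    // Index-based interning, mirroring the deleted pool's put() contract.
    class InternPool {
      private val objects = scala.collection.mutable.ArrayBuffer.empty[Any]
      private val indexOf = scala.collection.mutable.HashMap.empty[Any, Int]
      def put(obj: Any): Int = indexOf.getOrElseUpdate(obj, {
        objects += obj
        objects.length - 1
      })
      def get(idx: Int): Any = objects(idx)
    }

    val pool = new InternPool
    assert(pool.put("spark") == pool.put("spark")) // equal keys share one index
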
- */ - @Override - public void replace(int idx, Object obj) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 39fd6e1bc6d13..be4ff400c4754 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -72,7 +71,7 @@ public UnsafeExternalRowSorter( sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, - new RowComparator(ordering, schema.length(), null), + new RowComparator(ordering, schema.length()), prefixComparator, 4096, sparkEnv.conf() @@ -140,8 +139,7 @@ public InternalRow next() { sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), numFields, - sortedIterator.getRecordLength(), - null); + sortedIterator.getRecordLength()); if (!hasNext()) { row.copy(); // so that we don't have dangling pointers to freed page cleanupResources(); @@ -174,27 +172,25 @@ public Iterator sort(Iterator inputIterator) throws IO * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise. */ public static boolean supportsSchema(StructType schema) { - // TODO: add spilling note to explain why we do this for now: return UnsafeProjection.canSupport(schema); } private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; - private final ObjectPool objPool; private final UnsafeRow row1 = new UnsafeRow(); private final UnsafeRow row2 = new UnsafeRow(); - public RowComparator(Ordering ordering, int numFields, ObjectPool objPool) { + public RowComparator(Ordering ordering, int numFields) { this.numFields = numFields; this.ordering = ordering; - this.objPool = objPool; } @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool); - row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool); + // TODO: Why are the sizes -1? 
+ row1.pointTo(baseObj1, baseOff1, numFields, -1); + row2.pointTo(baseObj2, baseOff2, numFields, -1); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index ae0ab2f4c63f5..4067833d5e648 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -281,7 +281,8 @@ object CatalystTypeConverters { } override def toScala(catalystValue: UTF8String): String = if (catalystValue == null) null else catalystValue.toString - override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString + override def toScalaImpl(row: InternalRow, column: Int): String = + row.getUTF8String(column).toString } private object DateConverter extends CatalystTypeConverter[Date, Date, Any] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 024973a6b9fcd..c7ec49b3d6c3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -27,11 +27,12 @@ import org.apache.spark.unsafe.types.UTF8String */ abstract class InternalRow extends Row { + def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + // This is only use for test - override def getString(i: Int): String = { - val str = getAs[UTF8String](i) - if (str != null) str.toString else null - } + override def getString(i: Int): String = getAs[UTF8String](i).toString // These expensive API should not be used internally. 
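These InternalRow additions are the "specialized getters" from point 3 of the revert rationale: a StringType read can go through a statically typed accessor instead of Object get(i) plus a cast. A rough illustration, with plain String standing in for UTF8String:

    // Why a type-specific accessor helps generated code.
    final class StringRow(values: Array[String]) {
      def get(i: Int): Any = values(i)             // generic: forces a cast at the call site
      def getStringCol(i: Int): String = values(i) // specialized: cast-free at the call site
    }

    val row = new StringRow(Array("a", "b"))
    val before: String = row.get(0).asInstanceOf[String] // what a pooled read looked like
    val after: String = row.getStringCol(0)              // analogous to getUTF8String
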
final override def getDecimal(i: Int): java.math.BigDecimal = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4a13b687bf4ce..6aa4930cb8587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -46,6 +46,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case LongType | TimestampType => input.getLong(ordinal) case FloatType => input.getFloat(ordinal) case DoubleType => input.getDouble(ordinal) + case StringType => input.getUTF8String(ordinal) + case BinaryType => input.getBinary(ordinal) case _ => input.get(ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 69758e653eba0..04872fbc8b091 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.unsafe.types.UTF8String /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -177,6 +178,14 @@ class JoinedRow extends InternalRow { override def length: Int = row1.length + row2.length + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) @@ -271,6 +280,14 @@ class JoinedRow2 extends InternalRow { override def length: Int = row1.length + row2.length + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) @@ -359,6 +376,15 @@ class JoinedRow3 extends InternalRow { override def length: Int = row1.length + row2.length + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) @@ -447,6 +473,15 @@ class JoinedRow4 extends InternalRow { override def length: Int = row1.length + row2.length + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + 
override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) @@ -535,6 +570,15 @@ class JoinedRow5 extends InternalRow { override def length: Int = row1.length + row2.length + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) @@ -623,6 +667,15 @@ class JoinedRow6 extends InternalRow { override def length: Int = row1.length + row2.length + override def getUTF8String(i: Int): UTF8String = { + if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + } + + override def getBinary(i: Int): Array[Byte] = { + if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + } + + override def get(i: Int): Any = if (i < row1.length) row1(i) else row2(i - row1.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 885ab091fcdf5..c47b16c0f8585 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,13 +17,15 @@ package org.apache.spark.sql.catalyst.expressions +import scala.util.Try + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String + /** * Converts Rows into UnsafeRow format. This class is NOT thread-safe. 
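The converter restored below is used in a fixed two-step protocol visible throughout this patch: call getSizeRequirement(row), allocate exactly that many bytes, then writeRow must report the same count (the callers assert on it). A hedged sketch of that contract with hypothetical names, not Spark's classes:

    // Size-then-write: the writer's two methods must agree byte-for-byte.
    trait RowWriter[T] {
      def sizeOf(value: T): Int
      // Writes value into buf starting at offset; returns the bytes written,
      // which must equal sizeOf(value).
      def write(value: T, buf: Array[Byte], offset: Int): Int
    }

    def encode[T](value: T, writer: RowWriter[T]): Array[Byte] = {
      val size = writer.sizeOf(value)
      val buf = new Array[Byte](size)
      val written = writer.write(value, buf, 0)
      assert(written == size, "Size requirement calculation was wrong!")
      buf
    }
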
* @@ -35,8 +37,6 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } - def numFields: Int = fieldTypes.length - /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -77,9 +77,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { row: InternalRow, baseObject: Object, baseOffset: Long, - rowLengthInBytes: Int, - pool: ObjectPool): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes, pool) + rowLengthInBytes: Int): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes) if (writers.length > 0) { // zero-out the bitset @@ -94,16 +93,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } var fieldNumber = 0 - var cursor: Int = fixedLengthSize + var appendCursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) + appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) } fieldNumber += 1 } - cursor + appendCursor } } @@ -118,11 +117,11 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param cursor the offset from the start of the unsafe row to the end of the row; + * @param appendCursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int + def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. @@ -144,21 +143,19 @@ private object UnsafeColumnWriter { case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case t => ObjectUnsafeColumnWriter + case t => + throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") } } /** * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool). 
*/ - def canEmbed(dataType: DataType): Boolean = { - forType(dataType) != ObjectUnsafeColumnWriter - } + def canEmbed(dataType: DataType): Boolean = Try(forType(dataType)).isSuccess } // ------------------------------------------------------------------------------------------------ - private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: def getSize(sourceRow: InternalRow, column: Int): Int = 0 @@ -249,8 +246,7 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { offset, numBytes ) - val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 - target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) + target.setLong(column, (cursor.toLong << 32) | numBytes.toLong) ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } } @@ -278,13 +274,3 @@ private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter { def getSize(value: Array[Byte]): Int = ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length) } - -private object ObjectUnsafeColumnWriter extends UnsafeColumnWriter { - override def getSize(sourceRow: InternalRow, column: Int): Int = 0 - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val obj = source.get(column) - val idx = target.getPool.put(obj) - target.setLong(column, - idx) - 0 - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 319dcd1c04316..48225e1574600 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -105,10 +105,11 @@ class CodeGenContext { */ def getColumn(row: String, dataType: DataType, ordinal: Int): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.get${primitiveTypeName(jt)}($ordinal)" - } else { - s"($jt)$row.apply($ordinal)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" + case StringType => s"$row.getUTF8String($ordinal)" + case BinaryType => s"$row.getBinary($ordinal)" + case _ => s"($jt)$row.apply($ordinal)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3a8e8302b24fd..d65e5c38ebf5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -98,7 +98,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } public UnsafeRow apply(InternalRow i) { - ${allExprs} + $allExprs // additionalSize had '+' in the beginning int numBytes = $fixedSize $additionalSize; @@ -106,7 +106,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro buffer = new byte[numBytes]; } target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, - ${expressions.size}, numBytes, null); + ${expressions.size}, numBytes); int cursor = $fixedSize; $writers return target; diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 1868f119f0e97..e3e7a11dba973 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} import org.apache.spark.sql.types.{StructField, StructType} @@ -28,6 +29,12 @@ object LocalRelation { new LocalRelation(StructType(output1 +: output).toAttributes) } + def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { + val schema = StructType.fromAttributes(output) + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + LocalRelation(output, data.map(converter(_).asInstanceOf[InternalRow])) + } + def fromProduct(output: Seq[Attribute], data: Seq[Product]): LocalRelation = { val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index c9667e90a0aaa..7566cb59e34ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -24,9 +24,8 @@ import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} import org.apache.spark.unsafe.types.UTF8String @@ -35,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite with Matchers with BeforeAndAfterEach { + import UnsafeFixedWidthAggregationMap._ + private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) - private def emptyProjection: Projection = - GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ -54,11 +53,21 @@ class UnsafeFixedWidthAggregationMapSuite } } + test("supported schemas") { + assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) + assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) + + assert( + !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + assert( + !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) + } + test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new 
UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 1024, // initial capacity false // disable perf metrics @@ -69,9 +78,9 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 1024, // initial capacity false // disable perf metrics @@ -95,9 +104,9 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), + emptyAggregationBuffer, + aggBufferSchema, + groupKeySchema, memoryManager, 128, // initial capacity false // disable perf metrics @@ -112,36 +121,6 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) - - map.free() - } - - test("with decimal in the key and values") { - val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) - val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) - val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), - Seq(AttributeReference("price", DecimalType.Unlimited)())) - val map = new UnsafeFixedWidthAggregationMap( - emptyProjection, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggBufferSchema), - memoryManager, - 1, // initial capacity - false // disable perf metrics - ) - - (0 until 100).foreach { i => - val groupKey = InternalRow(Decimal(i % 10)) - val row = map.getAggregationBuffer(groupKey) - row.update(0, Decimal(i)) - } - val seenKeys: Set[Int] = map.iterator().asScala.map { entry => - entry.key.getAs[Decimal](0).toInt - }.toSet - seenKeys.size should be (10) - seenKeys should be ((0 until 10).toSet) - - map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index dff5faf9f6ec8..8819234e78e60 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -45,12 +45,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, 
fieldTypes.length, sizeRequired) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) @@ -87,67 +86,15 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) + row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - val pool = new ObjectPool(10) unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) + buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") - assert(unsafeRow.get(2) === "World".getBytes) - - unsafeRow.update(1, UTF8String.fromString("World")) - assert(unsafeRow.getString(1) === "World") - assert(pool.size === 0) - unsafeRow.update(1, UTF8String.fromString("Hello World")) - assert(unsafeRow.getString(1) === "Hello World") - assert(pool.size === 1) - - unsafeRow.update(2, "World".getBytes) - assert(unsafeRow.get(2) === "World".getBytes) - assert(pool.size === 1) - unsafeRow.update(2, "Hello World".getBytes) - assert(unsafeRow.get(2) === "Hello World".getBytes) - assert(pool.size === 2) - - // We do not support copy() for UnsafeRows that reference ObjectPools - intercept[UnsupportedOperationException] { - unsafeRow.copy() - } - } - - test("basic conversion with primitive, decimal and array") { - val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) - val converter = new UnsafeRowConverter(fieldTypes) - - val row = new SpecificMutableRow(fieldTypes) - row.setLong(0, 0) - row.update(1, Decimal(1)) - row.update(2, Array(2)) - - val pool = new ObjectPool(10) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 3)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, pool) - assert(numBytesWritten === sizeRequired) - assert(pool.size === 2) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, pool) - assert(unsafeRow.getLong(0) === 0) - assert(unsafeRow.get(1) === Decimal(1)) - assert(unsafeRow.get(2) === Array(2)) - - unsafeRow.update(1, Decimal(2)) - assert(unsafeRow.get(1) === Decimal(2)) - unsafeRow.update(2, Array(3, 4)) - assert(unsafeRow.get(2) === Array(3, 4)) - assert(pool.size === 2) + assert(unsafeRow.getBinary(2) === "World".getBytes) } test("basic conversion with primitive, string, date and timestamp types") { @@ -165,25 +112,25 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired, null) + converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired, null) + buffer, PlatformDependent.LONG_ARRAY_OFFSET, 
fieldTypes.length, sizeRequired) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be - (Timestamp.valueOf("2015-05-08 08:10:25")) + (Timestamp.valueOf("2015-05-08 08:10:25")) unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be - (Timestamp.valueOf("2015-06-22 08:10:25")) + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -197,9 +144,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { FloatType, DoubleType, StringType, - BinaryType, - DecimalType.Unlimited, - ArrayType(IntegerType) + BinaryType + // DecimalType.Unlimited, + // ArrayType(IntegerType) ) val converter = new UnsafeRowConverter(fieldTypes) @@ -215,14 +162,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired, null) + sizeRequired) assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired, null) - for (i <- 0 to fieldTypes.length - 1) { + createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) + for (i <- fieldTypes.indices) { assert(createdFromNull.isNullAt(i)) } assert(createdFromNull.getBoolean(1) === false) @@ -232,10 +178,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) - assert(createdFromNull.getString(8) === null) - assert(createdFromNull.get(9) === null) - assert(createdFromNull.get(10) === null) - assert(createdFromNull.get(11) === null) + assert(createdFromNull.getUTF8String(8) === null) + assert(createdFromNull.getBinary(9) === null) + // assert(createdFromNull.get(10) === null) + // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by @@ -252,19 +198,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - r.update(10, Decimal(10)) - r.update(11, Array(11)) + // r.update(10, Decimal(10)) + // r.update(11, Array(11)) r } - val pool = new ObjectPool(1) val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired, pool) + sizeRequired) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired, pool) + sizeRequired) 
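// (Note: writeRow and pointTo no longer take an ObjectPool argument anywhere in this suite; with strings and binaries written inline in the variable-length region, an UnsafeRow is now fully self-contained in its byte buffer.)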
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -275,14 +220,11 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) - assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.get(9)) + // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) - for (i <- 0 to fieldTypes.length - 1) { - if (i >= 8) { - setToNullAfterCreation.update(i, null) - } + for (i <- fieldTypes.indices) { setToNullAfterCreation.setNullAt(i) } // There are some garbage left in the var-length area @@ -297,10 +239,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setLong(5, 500) setToNullAfterCreation.setFloat(6, 600) setToNullAfterCreation.setDouble(7, 700) - setToNullAfterCreation.update(8, UTF8String.fromString("hello")) - setToNullAfterCreation.update(9, "world".getBytes) - setToNullAfterCreation.update(10, Decimal(10)) - setToNullAfterCreation.update(11, Array(11)) + // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + // setToNullAfterCreation.update(9, "world".getBytes) + // setToNullAfterCreation.update(10, Decimal(10)) + // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -310,10 +252,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) - assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) - assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) + // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } test("NaN canonicalization") { @@ -330,12 +272,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val converter = new UnsafeRowConverter(fieldTypes) val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1)) val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2)) - converter.writeRow( - row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null) - converter.writeRow( - row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null) + converter.writeRow(row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, 
row1Buffer.length) + converter.writeRow(row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length) assert(row1Buffer.toSeq === row2Buffer.toSeq) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala deleted file mode 100644 index 94764df4b9cdb..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.util - -import org.scalatest.Matchers - -import org.apache.spark.SparkFunSuite - -class ObjectPoolSuite extends SparkFunSuite with Matchers { - - test("pool") { - val pool = new ObjectPool(1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - assert(pool.put(false) === 2) - - assert(pool.get(0) === 1) - assert(pool.get(1) === "hello") - assert(pool.get(2) === false) - assert(pool.size() === 3) - - pool.replace(1, "world") - assert(pool.get(1) === "world") - assert(pool.size() === 3) - } - - test("unique pool") { - val pool = new UniqueObjectPool(1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - assert(pool.put(1) === 0) - assert(pool.put("hello") === 1) - - assert(pool.get(0) === 1) - assert(pool.get(1) === "hello") - assert(pool.size() === 2) - - intercept[UnsupportedOperationException] { - pool.replace(1, "world") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 323ff17357fda..fa942a1f8fd93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.util.Properties +import org.apache.spark.unsafe.types.UTF8String + import scala.language.implicitConversions import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -1282,7 +1284,7 @@ class DataFrame private[sql]( val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList - val ret: Seq[InternalRow] = if (outputCols.nonEmpty) { + val ret: Seq[Row] = if (outputCols.nonEmpty) { val aggExprs = statistics.flatMap { case (_, colToAgg) => outputCols.map(c => Column(Cast(colToAgg(Column(c).expr), StringType)).as(c)) } @@ -1290,19 +1292,18 @@ class DataFrame private[sql]( val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq // Pivot the data so each summary is one row - row.grouped(outputCols.size).toSeq.zip(statistics).map { - case (aggregation, (statistic, _)) => - InternalRow(statistic :: aggregation.toList: _*) + 
row.grouped(outputCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) => + Row(statistic :: aggregation.toList: _*) } } else { // If there are no output columns, just output a single column that contains the stats. - statistics.map { case (name, _) => InternalRow(name) } + statistics.map { case (name, _) => Row(name) } } // All columns are string type val schema = StructType( StructField("summary", StringType) :: outputCols.map(StructField(_, StringType))).toAttributes - LocalRelation(schema, ret) + LocalRelation.fromExternalRows(schema, ret) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 0e63f2fe29cb3..16176abe3a51d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -239,6 +239,11 @@ case class GeneratedAggregate( StructType(fields) } + val schemaSupportsUnsafe: Boolean = { + UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && + UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + } + child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -290,13 +295,14 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled) { + + } else if (unsafeEnabled && schemaSupportsUnsafe) { assert(iter.hasNext, "There should be at least one row for this path") log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer, - new UnsafeRowConverter(groupKeySchema), - new UnsafeRowConverter(aggregationBufferSchema), + newAggregationBuffer(EmptyRow), + aggregationBufferSchema, + groupKeySchema, TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics @@ -331,6 +337,9 @@ case class GeneratedAggregate( } } } else { + if (unsafeEnabled) { + log.info("Not using Unsafe-based aggregator because it is not supported for this schema") + } val buffers = new java.util.HashMap[InternalRow, MutableRow]() var currentRow: InternalRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index cd341180b6100..34e926e4582be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -34,13 +34,11 @@ private[sql] case class LocalTableScan( protected override def doExecute(): RDD[InternalRow] = rdd - override def executeCollect(): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]).toArray } - override def executeTake(limit: Int): Array[Row] = { val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]).take(limit).toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 318550e5ed899..16498da080c88 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -37,9 +37,6 @@ import org.apache.spark.unsafe.PlatformDependent * Note that this serializer implements only the [[Serializer]] methods that are used during * shuffle, so certain [[SerializerInstance]] methods will throw UnsupportedOperationException. * - * This serializer does not support UnsafeRows that use - * [[org.apache.spark.sql.catalyst.util.ObjectPool]]. - * * @param numFields the number of fields in the row being serialized. */ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with Serializable { @@ -65,7 +62,6 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def writeValue[T: ClassTag](value: T): SerializationStream = { val row = value.asInstanceOf[UnsafeRow] - assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool") dOut.writeInt(row.getSizeInBytes) row.writeToStream(out, writeBuffer) this @@ -118,7 +114,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(in, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) rowSize = dIn.readInt() // read the next row's size if (rowSize == EOF) { // We are returning the last row in this stream val _rowTuple = rowTuple @@ -152,7 +148,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(in, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index d36e2639376e7..ad3bb1744cb3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -53,8 +53,7 @@ class UnsafeRowSuite extends SparkFunSuite { offheapRowPage.getBaseObject, offheapRowPage.getBaseOffset, 3, // num fields - arrayBackedUnsafeRow.getSizeInBytes, - null // object pool + arrayBackedUnsafeRow.getSizeInBytes ) assert(offheapUnsafeRow.getBaseObject === null) val baos = new ByteArrayOutputStream() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 5fe73f7e0b072..7a4baa9e4a49d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -39,7 +39,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { ignore("sort followed by limit should not leak memory") { // TODO: this test is going to fail until we implement a proper iterator interface // with a close() method. 
- TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), @@ -58,7 +58,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { sortAnswers = false ) } finally { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true") + TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") } } @@ -91,7 +91,8 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23), + plan => ConvertToSafe( + UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index bd788ec8c14b1..a1e1695717e23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -23,29 +23,25 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter} -import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent class UnsafeRowSerializerSuite extends SparkFunSuite { - private def toUnsafeRow( - row: Row, - schema: Array[DataType], - objPool: ObjectPool = null): UnsafeRow = { + private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = { val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow] val rowConverter = new UnsafeRowConverter(schema) val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow) val byteArray = new Array[Byte](rowSizeInBytes) rowConverter.writeRow( - internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes, objPool) + internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes, objPool) + unsafeRow.pointTo(byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes) unsafeRow } - test("toUnsafeRow() test helper method") { + ignore("toUnsafeRow() test helper method") { + // This currently doesn't work because the generic getter throws an exception.
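+ // (Elsewhere in this patch, generic UnsafeRow.get() calls are replaced with typed getters such as getUTF8String and getBinary; the unsafeRow.get(0) call below still goes through the generic path.)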
val row = Row("Hello", 123) val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType)) assert(row.getString(0) === unsafeRow.get(0).toString) From 26ed22aec8af42c6dc161e0a2827a4235a49a9a4 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Thu, 23 Jul 2015 12:43:54 +0100 Subject: [PATCH 026/219] [SPARK-9212] [CORE] upgrade Netty version to 4.0.29.Final related JIRA: [SPARK-9212](https://issues.apache.org/jira/browse/SPARK-9212) and [SPARK-8101](https://issues.apache.org/jira/browse/SPARK-8101) Author: Zhang, Liye Closes #7562 from liyezhang556520/SPARK-9212 and squashes the following commits: 1917729 [Zhang, Liye] SPARK-9212 upgrade Netty version to 4.0.29.Final --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 1f44dc8abe1d4..35fc8c44bc1b0 100644 --- a/pom.xml +++ b/pom.xml @@ -573,7 +573,7 @@ io.netty netty-all - 4.0.28.Final + 4.0.29.Final org.apache.derby From 52ef76de219c4bf19c54c99414b89a67d0bf457b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Jul 2015 09:37:53 -0700 Subject: [PATCH 027/219] [SPARK-9082] [SQL] [FOLLOW-UP] use `partition` in `PushPredicateThroughProject` a follow up of https://github.com/apache/spark/pull/7446 Author: Wenchen Fan Closes #7607 from cloud-fan/tmp and squashes the following commits: 7106989 [Wenchen Fan] use `partition` in `PushPredicateThroughProject` --- .../sql/catalyst/optimizer/Optimizer.scala | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d2db3dd3d078e..b59f800e7cc0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -553,33 +553,27 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe // Split the condition into small conditions by `And`, so that we can push down part of this // condition without nondeterministic expressions. val andConditions = splitConjunctivePredicates(condition) - val nondeterministicConditions = andConditions.filter(hasNondeterministic(_, aliasMap)) + + val (deterministic, nondeterministic) = andConditions.partition(_.collect { + case a: Attribute if aliasMap.contains(a) => aliasMap(a) + }.forall(_.deterministic)) // If there is no nondeterministic conditions, push down the whole condition. - if (nondeterministicConditions.isEmpty) { + if (nondeterministic.isEmpty) { project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) } else { // If they are all nondeterministic conditions, leave it un-changed. - if (nondeterministicConditions.length == andConditions.length) { + if (deterministic.isEmpty) { filter } else { - val deterministicConditions = andConditions.filterNot(hasNondeterministic(_, aliasMap)) // Push down the small conditions without nondeterministic expressions. 
- val pushedCondition = deterministicConditions.map(replaceAlias(_, aliasMap)).reduce(And) - Filter(nondeterministicConditions.reduce(And), + val pushedCondition = deterministic.map(replaceAlias(_, aliasMap)).reduce(And) + Filter(nondeterministic.reduce(And), project.copy(child = Filter(pushedCondition, grandChild))) } } } - private def hasNondeterministic( - condition: Expression, - sourceAliases: AttributeMap[Expression]) = { - condition.collect { - case a: Attribute if sourceAliases.contains(a) => sourceAliases(a) - }.exists(!_.deterministic) - } - // Substitute any attributes that are produced by the child projection, so that we safely // eliminate it. private def replaceAlias(condition: Expression, sourceAliases: AttributeMap[Expression]) = { From 19aeab57c1b0c739edb5ba351f98e930e1a0f984 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 23 Jul 2015 10:28:20 -0700 Subject: [PATCH 028/219] [Build][Minor] Fix building error & performance 1. When building the latest code with sbt, it throws an exception like: [error] /home/hcheng/git/catalyst/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala:78: match may not be exhaustive. [error] It would fail on the following input: UNKNOWN [error] val classNameByStatus = status match { [error] 2. Potential performance issue when implicitly converting an Array[Any] to Seq[Any] Author: Cheng Hao Closes #7611 from chenghao-intel/toseq and squashes the following commits: cab75c5 [Cheng Hao] remove the toArray 24df682 [Cheng Hao] fix building error & performance --- core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 1 + .../org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 2ce670ad02e97..e72547df7254b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -79,6 +79,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { case JobExecutionStatus.SUCCEEDED => "succeeded" case JobExecutionStatus.FAILED => "failed" case JobExecutionStatus.RUNNING => "running" + case JobExecutionStatus.UNKNOWN => "unknown" } // The timeline library treats contents as HTML, so we have to escape them; for the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 4067833d5e648..bfaee04f33b7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -402,7 +402,7 @@ object CatalystTypeConverters { case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) case seq: Seq[Any] => seq.map(convertToCatalyst) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) - case arr: Array[Any] => arr.toSeq.map(convertToCatalyst).toArray + case arr: Array[Any] => arr.map(convertToCatalyst) case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other From d2666a3c70dad037776dc4015fa561356381357b Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 23 Jul 2015 10:31:12 -0700 Subject: [PATCH 029/219] [SPARK-9183] confusing error message when looking up missing function in Spark SQL JIRA:
https://issues.apache.org/jira/browse/SPARK-9183 cc rxin Author: Yijie Shen Closes #7613 from yjshen/npe_udf and squashes the following commits: 44f58f2 [Yijie Shen] add jira ticket number 903c963 [Yijie Shen] add explanation comments f44dd3c [Yijie Shen] Change two hive class LogLevel to avoid annoying messages --- conf/log4j.properties.template | 4 ++++ .../resources/org/apache/spark/log4j-defaults-repl.properties | 4 ++++ .../main/resources/org/apache/spark/log4j-defaults.properties | 4 ++++ 3 files changed, 12 insertions(+) diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 3a2a88219818f..27006e45e932b 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties index b146f8a784127..689afea64f8db 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults-repl.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 3a2a88219818f..27006e45e932b 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,3 +10,7 @@ log4j.logger.org.spark-project.jetty=WARN log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO + +# SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL +log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR From ecfb3127670c7f15e3a15e7f51fa578532480cda Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 23 Jul 2015 10:32:11 -0700 Subject: [PATCH 030/219] [SPARK-9243] [Documentation] null -> zero in crosstab doc We forgot to update doc. 
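For illustration, a minimal sketch of the behavior being documented (hypothetical data; assumes a SQLContext with its implicits imported):

    import sqlContext.implicits._
    val df = Seq((1, "a"), (1, "b"), (2, "a")).toDF("key", "value")
    df.stat.crosstab("key", "value").show()
    // key_value   a   b
    // 1           1   1
    // 2           1   0   <- the absent (2, "b") pair is counted as zero, not null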
brkyvz Author: Xiangrui Meng Closes #7608 from mengxr/SPARK-9243 and squashes the following commits: 0ea3236 [Xiangrui Meng] null -> zero in crosstab doc --- R/pkg/R/DataFrame.R | 2 +- python/pyspark/sql/dataframe.py | 2 +- .../scala/org/apache/spark/sql/DataFrameStatFunctions.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 06dd6b75dff3d..f4c93d3c7dd67 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1566,7 +1566,7 @@ setMethod("fillna", #' @return a local R data.frame representing the contingency table. The first column of each row #' will be the distinct values of `col1` and the column names will be the distinct values #' of `col2`. The name of the first column will be `$col1_$col2`. Pairs that have no -#' occurrences will have `null` as their counts. +#' occurrences will have zero as their counts. #' #' @rdname statfunctions #' @export diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 83e02b85f06f1..d76e051bd73a1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1130,7 +1130,7 @@ def crosstab(self, col1, col2): non-zero pair frequencies will be returned. The first column of each row will be the distinct values of `col1` and the column names will be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. - Pairs that have no occurrences will have `null` as their counts. + Pairs that have no occurrences will have zero as their counts. :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases. :param col1: The name of the first column. Distinct items will make the first item of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 587869e57f96e..4ec58082e7aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -77,7 +77,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * pair frequencies will be returned. * The first column of each row will be the distinct values of `col1` and the column names will * be the distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts - * will be returned as `Long`s. Pairs that have no occurrences will have `null` as their counts. + * will be returned as `Long`s. Pairs that have no occurrences will have zero as their counts. * Null elements will be replaced by "null", and back ticks will be dropped from elements if they * exist. * From 662d60db3f4a758b6869de5bd971d23bd5962c3b Mon Sep 17 00:00:00 2001 From: David Arroyo Cazorla Date: Thu, 23 Jul 2015 10:34:32 -0700 Subject: [PATCH 031/219] [SPARK-5447][SQL] Replace reference 'schema rdd' with DataFrame @rxin. 
Author: David Arroyo Cazorla Closes #7618 from darroyocazorla/master and squashes the following commits: 5f91379 [David Arroyo Cazorla] [SPARK-5447][SQL] Replace reference 'schema rdd' with DataFrame --- .../scala/org/apache/spark/sql/execution/CacheManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index a4b38d364d54a..d3e5c378d037d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -84,7 +84,7 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { } /** - * Caches the data produced by the logical representation of the given schema rdd. Unlike + * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing * the in-memory columnar representation of the underlying table is expensive. */ From b2f3aca1e8c182b93e250f9d9c4aa69f97eaa11a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 16:08:07 -0700 Subject: [PATCH 032/219] [SPARK-9286] [SQL] Methods in Unevaluable should be final and AlgebraicAggregate should extend Unevaluable. This patch marks the Unevaluable.eval() and UnevaluablegenCode() methods as final and fixes two cases where they were overridden. It also updates AggregateFunction2 to extend Unevaluable. Author: Josh Rosen Closes #7627 from JoshRosen/unevaluable-fix and squashes the following commits: 8d9ed22 [Josh Rosen] AlgebraicAggregate should extend Unevaluable 65329c2 [Josh Rosen] Do not have AggregateFunction1 inherit from AggregateExpression1 fa68a22 [Josh Rosen] Make eval() and genCode() final --- .../sql/catalyst/expressions/Expression.scala | 4 ++-- .../expressions/aggregate/interfaces.scala | 15 +++------------ .../sql/catalyst/expressions/aggregates.scala | 11 +++++------ 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 29ae47e842ddb..3f72e6e184db1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -184,10 +184,10 @@ abstract class Expression extends TreeNode[Expression] { */ trait Unevaluable extends Expression { - override def eval(input: InternalRow = null): Any = + final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = + final override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 577ede73cb01f..d3fee1ade05e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -63,10 +63,6 @@ private[sql] case object Complete extends AggregateMode */ private[sql] case object NoOp extends Expression with Unevaluable { override def nullable: Boolean = true - override def eval(input: InternalRow): Any = { - throw new TreeNodeException( - this, s"No function to evaluate expression. type: ${this.nodeName}") - } override def dataType: DataType = NullType override def children: Seq[Expression] = Nil } @@ -151,8 +147,7 @@ abstract class AggregateFunction2 /** * A helper class for aggregate functions that can be implemented in terms of catalyst expressions. */ -abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { - self: Product => +abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable with Unevaluable { val initialValues: Seq[Expression] val updateExpressions: Seq[Expression] @@ -188,19 +183,15 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable { } } - override def update(buffer: MutableRow, input: InternalRow): Unit = { + override final def update(buffer: MutableRow, input: InternalRow): Unit = { throw new UnsupportedOperationException( "AlgebraicAggregate's update should not be called directly") } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + override final def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { throw new UnsupportedOperationException( "AlgebraicAggregate's merge should not be called directly") } - override def eval(buffer: InternalRow): Any = { - throw new UnsupportedOperationException( - "AlgebraicAggregate's eval should not be called directly") - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index e07c920a41d0a..d3295b8bafa80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -71,8 +71,7 @@ trait PartialAggregate1 extends AggregateExpression1 { * A specific implementation of an aggregate function. Used to wrap a generic * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. */ -abstract class AggregateFunction1 - extends LeafExpression with AggregateExpression1 with Serializable { +abstract class AggregateFunction1 extends LeafExpression with Serializable { /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression1 @@ -82,9 +81,9 @@ abstract class AggregateFunction1 def update(input: InternalRow): Unit - // Do we really need this? 
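// (newInstance() and its reflection-based makeCopy are removed; the genCode override added below throws instead, since AggregateFunction1 should not be used for generated aggregates.)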
- override def newInstance(): AggregateFunction1 = { - makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + throw new UnsupportedOperationException( + "AggregateFunction1 should not be used for generated aggregates") } } From bebe3f7b45f7b0a96f20d5af9b80633fd40cff06 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 23 Jul 2015 17:49:33 -0700 Subject: [PATCH 033/219] [SPARK-9207] [SQL] Enables Parquet filter push-down by default PARQUET-136 and PARQUET-173 have been fixed in parquet-mr 1.7.0. It's time to enable filter push-down by default now. Author: Cheng Lian Closes #7612 from liancheng/spark-9207 and squashes the following commits: 77e6b5e [Cheng Lian] Enables Parquet filter push-down by default --- docs/sql-programming-guide.md | 9 ++------- .../src/main/scala/org/apache/spark/sql/SQLConf.scala | 8 ++------ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 5838bc172fe86..95945eb7fc8a0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1332,13 +1332,8 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` spark.sql.parquet.filterPushdown - false - - Turn on Parquet filter pushdown optimization. This feature is turned off by default because of a known - bug in Parquet 1.6.0rc3 (PARQUET-136). - However, if your table doesn't contain any nullable string or binary columns, it's still safe to turn - this feature on. - + true + Enables Parquet filter push-down optimization when set to true. spark.sql.hive.convertMetastoreParquet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 1474b170ba896..2a641b9d64a95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -273,12 +273,8 @@ private[spark] object SQLConf { "uncompressed, snappy, gzip, lzo.") val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", - defaultValue = Some(false), - doc = "Turn on Parquet filter pushdown optimization. This feature is turned off by default " + - "because of a known bug in Parquet 1.6.0rc3 " + - "(PARQUET-136, https://issues.apache.org/jira/browse/PARQUET-136). However, " + - "if your table doesn't contain any nullable string or binary columns, it's still safe to " + - "turn this feature on.") + defaultValue = Some(true), + doc = "Enables Parquet filter push-down optimization when set to true.") val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi", defaultValue = Some(true), From 8a94eb23d53e291441e3144a1b800fe054457040 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 23 Jul 2015 18:31:13 -0700 Subject: [PATCH 034/219] [SPARK-9069] [SPARK-9264] [SQL] remove unlimited precision support for DecimalType Remove Decimal.Unlimited (change to support precision up to 38, to match Hive and other databases). In order to keep backward source compatibility, Decimal.Unlimited is still there, but changed to Decimal(38, 18). If no precision and scale are provided, it's Decimal(10, 0) as before.
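As a quick hedged sketch of the resulting defaults (constant names taken from the diff below; illustrative only, not part of the patch):

    import org.apache.spark.sql.types.DecimalType

    // Used when no precision and scale are specified:
    val userDefault = DecimalType.USER_DEFAULT     // DecimalType(10, 0)
    // Code that previously asked for Decimal.Unlimited now gets:
    val systemDefault = DecimalType.SYSTEM_DEFAULT // DecimalType(38, 18)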
Author: Davies Liu Closes #7605 from davies/decimal_unlimited and squashes the following commits: aa3f115 [Davies Liu] fix tests and style fb0d20d [Davies Liu] address comments bfaae35 [Davies Liu] fix style df93657 [Davies Liu] address comments and clean up 06727fd [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_unlimited 4c28969 [Davies Liu] fix tests 8d783cc [Davies Liu] fix tests 788631c [Davies Liu] fix double with decimal in Union/except 1779bde [Davies Liu] fix scala style c9c7c78 [Davies Liu] remove Decimal.Unlimited --- .../spark/ml/attribute/AttributeSuite.scala | 2 +- python/pyspark/sql/types.py | 36 +-- .../org/apache/spark/sql/types/DataTypes.java | 8 +- .../sql/catalyst/JavaTypeInference.scala | 2 +- .../spark/sql/catalyst/ScalaReflection.scala | 10 +- .../apache/spark/sql/catalyst/SqlParser.scala | 5 +- .../catalyst/analysis/HiveTypeCoercion.scala | 255 +++++++----------- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 7 +- .../expressions/aggregate/functions.scala | 24 +- .../sql/catalyst/expressions/aggregates.scala | 46 ++-- .../sql/catalyst/expressions/arithmetic.scala | 17 +- .../sql/catalyst/expressions/literals.scala | 6 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../org/apache/spark/sql/types/DataType.scala | 4 +- .../spark/sql/types/DataTypeParser.scala | 2 +- .../apache/spark/sql/types/DecimalType.scala | 110 +++++--- .../spark/sql/RandomDataGenerator.scala | 4 +- .../spark/sql/RandomDataGeneratorSuite.scala | 4 +- .../sql/catalyst/ScalaReflectionSuite.scala | 14 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../analysis/DecimalPrecisionSuite.scala | 54 ++-- .../analysis/HiveTypeCoercionSuite.scala | 45 ++-- .../sql/catalyst/expressions/CastSuite.scala | 46 ++-- .../ConditionalExpressionSuite.scala | 2 +- .../expressions/LiteralExpressionSuite.scala | 2 +- .../expressions/NullFunctionsSuite.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 2 + .../expressions/UnsafeRowConverterSuite.scala | 2 +- .../spark/sql/types/DataTypeParserSuite.scala | 4 +- .../spark/sql/types/DataTypeSuite.scala | 4 +- .../spark/sql/types/DataTypeTestUtils.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../spark/sql/columnar/ColumnType.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 10 +- .../datasources/PartitioningUtils.scala | 7 +- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 22 +- .../org/apache/spark/sql/jdbc/jdbc.scala | 7 +- .../apache/spark/sql/json/InferSchema.scala | 11 +- .../sql/parquet/CatalystSchemaConverter.scala | 4 - .../sql/parquet/ParquetTableSupport.scala | 8 +- .../spark/sql/JavaApplySchemaSuite.java | 14 +- .../columnar/InMemoryColumnarQuerySuite.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../execution/SparkSqlSerializer2Suite.scala | 4 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 20 +- .../org/apache/spark/sql/json/JsonSuite.scala | 57 ++-- .../spark/sql/parquet/ParquetIOSuite.scala | 8 - .../ParquetPartitionDiscoverySuite.scala | 2 +- .../spark/sql/sources/DDLTestSuite.scala | 2 +- .../spark/sql/sources/TableScanSuite.scala | 2 +- .../spark/sql/hive/HiveInspectors.scala | 9 +- .../org/apache/spark/sql/hive/HiveQl.scala | 4 +- 53 files changed, 459 insertions(+), 473 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala index c5fd2f9d5a22a..6355e0f179496 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala @@ -218,7 +218,7 @@ class AttributeSuite extends SparkFunSuite { // Attribute.fromStructField should accept any NumericType, not just DoubleType val longFldWithMeta = new StructField("x", LongType, false, metadata) assert(Attribute.fromStructField(longFldWithMeta).isNumeric) - val decimalFldWithMeta = new StructField("x", DecimalType(None), false, metadata) + val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata) assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric) } } diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 10ad89ea14a8d..b97d50c945f24 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -194,30 +194,33 @@ def fromInternal(self, ts): class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. + + The DecimalType must have fixed precision (the maximum total number of digits) + and scale (the number of digits to the right of the decimal point). For example, (5, 2) can + support values in the range [-999.99, 999.99]. + + The precision can be up to 38; the scale must be less than or equal to the precision. + + When creating a DecimalType, the default precision and scale are (10, 0). When inferring + a schema from decimal.Decimal objects, it will be DecimalType(38, 18). + + :param precision: the maximum total number of digits (default: 10) + :param scale: the number of digits to the right of the decimal point (default: 0) """ - def __init__(self, precision=None, scale=None): + def __init__(self, precision=10, scale=0): self.precision = precision self.scale = scale - self.hasPrecisionInfo = precision is not None + self.hasPrecisionInfo = True # this is a public API def simpleString(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal(10,0)" + return "decimal(%d,%d)" % (self.precision, self.scale) def jsonValue(self): - if self.hasPrecisionInfo: - return "decimal(%d,%d)" % (self.precision, self.scale) - else: - return "decimal" + return "decimal(%d,%d)" % (self.precision, self.scale) def __repr__(self): - if self.hasPrecisionInfo: - return "DecimalType(%d,%d)" % (self.precision, self.scale) - else: - return "DecimalType()" + return "DecimalType(%d,%d)" % (self.precision, self.scale) class DoubleType(FractionalType): @@ -761,7 +764,10 @@ def _infer_type(obj): return obj.__UDT__ dataType = _type_mappings.get(type(obj)) - if dataType is not None: + if dataType is DecimalType: + # the precision and scale of `obj` may be different from row to row. + return DecimalType(38, 18) + elif dataType is not None: return dataType() if isinstance(obj, dict): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index d22ad6794d608..5703de42393de 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -111,12 +111,18 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu return new ArrayType(elementType, containsNull); } + /** + * Creates a DecimalType by specifying the precision and scale. + */ public static DecimalType createDecimalType(int precision, int scale) { return DecimalType$.MODULE$.apply(precision, scale); } + /** + * Creates a DecimalType with default precision and scale, which are 10 and 0.
+ */ public static DecimalType createDecimalType() { - return DecimalType$.MODULE$.Unlimited(); + return DecimalType$.MODULE$.USER_DEFAULT(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 9a3f9694e4c48..88a457f87ce4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -75,7 +75,7 @@ private [sql] object JavaTypeInference { case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true) case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true) - case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true) + case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true) case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 21b1de1ab9cb1..2442341da106d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -131,10 +131,10 @@ trait ScalaReflection { case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) - case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.math.BigDecimal] => - Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true) @@ -167,8 +167,8 @@ trait ScalaReflection { case obj: Float => FloatType case obj: Double => DoubleType case obj: java.sql.Date => DateType - case obj: java.math.BigDecimal => DecimalType.Unlimited - case obj: Decimal => DecimalType.Unlimited + case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT + case obj: Decimal => DecimalType.SYSTEM_DEFAULT case obj: java.sql.Timestamp => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 29cfc064da89a..c494e5d704213 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -322,7 +322,10 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( 
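Note that two distinct defaults are in play after this change: the no-argument factory above returns the user default (10, 0), while type inference for java.math.BigDecimal values uses the system default (38, 18). A minimal sketch, assuming the patched 1.5 API in org.apache.spark.sql.types:

    import org.apache.spark.sql.types.{DataTypes, DecimalType}

    object DecimalDefaultsSketch {
      def main(args: Array[String]): Unit = {
        // The explicit factory with no arguments yields the user default ...
        println(DataTypes.createDecimalType())  // DecimalType(10,0)
        // ... while inferred BigDecimal columns get the system default, the
        // widest fixed-precision decimal Spark supports.
        println(DecimalType.SYSTEM_DEFAULT)     // DecimalType(38,18)
      }
    }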
integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
-    | sign.? ~ unsignedFloat ^^ { case s ~ f => Literal((s.getOrElse("") + f).toDouble) }
+    | sign.? ~ unsignedFloat ^^ {
+      // TODO(davies): some precision may be lost; we should create a decimal literal
+      case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue())
+    }
     )

   protected lazy val unsignedFloat: Parser[String] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e214545726249..d56ceeadc9e85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.analysis

 import javax.annotation.Nullable

+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
@@ -58,8 +60,7 @@ object HiveTypeCoercion {
     IntegerType,
     LongType,
     FloatType,
-    DoubleType,
-    DecimalType.Unlimited)
+    DoubleType)

   /**
    * Find the tightest common type of two types that might be used in a binary expression.
@@ -72,15 +73,16 @@ object HiveTypeCoercion {
     case (NullType, t1) => Some(t1)
     case (t1, NullType) => Some(t1)

-    // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+    case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
+      Some(t2)
+    case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
+      Some(t1)
+
+    // Promote numeric types to the highest of the two
     case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
       val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
       Some(numericPrecedence(index))

-    // Fixed-precision decimals can up-cast into unlimited
-    case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
-    case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)
-
     case _ => None
   }

@@ -101,7 +103,7 @@ object HiveTypeCoercion {
     types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
       case None => None
       case Some(d) =>
-        findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c))
+        findTightestCommonTypeToString(d, c)
     })
   }

@@ -158,6 +160,9 @@ object HiveTypeCoercion {
    *   converted to DOUBLE.
    * - TINYINT, SMALLINT, and INT can all be converted to FLOAT.
    * - BOOLEAN types cannot be converted to any other type.
+   * - Any integral numeric type can be implicitly converted to decimal type.
+   * - Two different decimal types will be converted into a decimal type wide enough for both.
+   * - A decimal type will be converted into double if a float or double appears together with it.
    *
    * Additionally, all types when UNION-ed with strings will be promoted to strings.
    * Other string conversions are handled by PromoteStrings.
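A simplified standalone mirror of the promotion rules above may make them easier to follow (this is not the HiveTypeCoercion code itself: the set of types and the wideness check, which stands in for the new DecimalType.isWiderThan, are pared down for illustration):

    object TightestCommonTypeSketch {
      sealed trait T
      case object IntT extends T
      case object LongT extends T
      case object DoubleT extends T
      final case class DecimalT(precision: Int, scale: Int) extends T

      // Numeric precedence, narrowest to widest; decimals are handled separately now.
      private val precedence: Seq[T] = Seq(IntT, LongT, DoubleT)

      // Integral digits an integral type needs when viewed as a decimal.
      private def digitsNeeded(t: T): Option[Int] = t match {
        case IntT  => Some(10)
        case LongT => Some(20)
        case _     => None
      }

      private def isWider(d: DecimalT, i: T): Boolean =
        digitsNeeded(i).exists(n => d.precision - d.scale >= n)

      def tightest(t1: T, t2: T): Option[T] = (t1, t2) match {
        case (a, b) if a == b => Some(a)
        // An integral type fits into a decimal with enough integral digits.
        case (d: DecimalT, i) if isWider(d, i) => Some(d)
        case (i, d: DecimalT) if isWider(d, i) => Some(d)
        // Otherwise promote both sides to the higher precedence entry.
        case (a, b) if Seq(a, b).forall(precedence.contains) =>
          Some(precedence(precedence.lastIndexWhere(t => t == a || t == b)))
        case _ => None // e.g. two different fixed decimals: DecimalPrecision decides later
      }

      def main(args: Array[String]): Unit = {
        println(tightest(IntT, DecimalT(38, 18))) // Some(DecimalT(38,18))
        println(tightest(LongT, DecimalT(10, 0))) // None: decimal(10,0) cannot hold a long
        println(tightest(IntT, DoubleT))          // Some(DoubleT)
      }
    }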
@@ -166,55 +171,50 @@ object HiveTypeCoercion { * - IntegerType to FloatType * - LongType to FloatType * - LongType to DoubleType + * - DecimalType to Double + * + * This rule is only applied to Union/Except/Intersect */ object WidenTypes extends Rule[LogicalPlan] { - private[this] def widenOutputTypes(planName: String, left: LogicalPlan, right: LogicalPlan): - (LogicalPlan, LogicalPlan) = { - - // TODO: with fixed-precision decimals - val castedInput = left.output.zip(right.output).map { - // When a string is found on one side, make the other side a string too. - case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType => - (lhs, Alias(Cast(rhs, StringType), rhs.name)()) - case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType => - (Alias(Cast(lhs, StringType), lhs.name)(), rhs) + private[this] def widenOutputTypes( + planName: String, + left: LogicalPlan, + right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => - logDebug(s"Resolving mismatched $planName input ${lhs.dataType}, ${rhs.dataType}") - findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType => - val newLeft = - if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)() - val newRight = - if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)() - - (newLeft, newRight) - }.getOrElse { - // If there is no applicable conversion, leave expression unchanged. - (lhs, rhs) + (lhs.dataType, rhs.dataType) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (t: FractionalType, d: DecimalType) => + Some(DoubleType) + case (d: DecimalType, t: FractionalType) => + Some(DoubleType) + case _ => + findTightestCommonTypeToString(lhs.dataType, rhs.dataType) } - - case other => other + case other => None } - val (castedLeft, castedRight) = castedInput.unzip - - val newLeft = - if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in $planName $castedLeft ${left.output}") - Project(castedLeft, left) - } else { - left + def castOutput(plan: LogicalPlan): LogicalPlan = { + val casted = plan.output.zip(castedTypes).map { + case (hs, Some(dt)) if dt != hs.dataType => + Alias(Cast(hs, dt), hs.name)() + case (hs, _) => hs } + Project(casted, plan) + } - val newRight = - if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in $planName $castedRight ${right.output}") - Project(castedRight, right) - } else { - right - } - (newLeft, newRight) + if (castedTypes.exists(_.isDefined)) { + (castOutput(left), castOutput(right)) + } else { + (left, right) + } } def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -334,144 +334,94 @@ object HiveTypeCoercion { * - SHORT gets turned into DECIMAL(5, 0) * - INT gets turned into DECIMAL(10, 0) * - LONG gets turned into DECIMAL(20, 0) - * - FLOAT and DOUBLE - * 1. Union, Intersect and Except operations: - * FLOAT gets turned into DECIMAL(7, 7), DOUBLE gets turned into DECIMAL(15, 15) (this is the - * same as Hive) - * 2. 
Other operation:
- *      FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
- *      but note that unlimited decimals are considered bigger than doubles in WidenTypes)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
+ *
+ * Note: Union/Except/Intersect is handled by WidenTypes
  */
 // scalastyle:on
 object DecimalPrecision extends Rule[LogicalPlan] {
   import scala.math.{max, min}

-  // Conversion rules for integer types into fixed-precision decimals
-  private val intTypeToFixed: Map[DataType, DecimalType] = Map(
-    ByteType -> DecimalType(3, 0),
-    ShortType -> DecimalType(5, 0),
-    IntegerType -> DecimalType(10, 0),
-    LongType -> DecimalType(20, 0)
-  )
-
   private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType

-  // Conversion rules for float and double into fixed-precision decimals
-  private val floatTypeToFixed: Map[DataType, DecimalType] = Map(
-    FloatType -> DecimalType(7, 7),
-    DoubleType -> DecimalType(15, 15)
-  )
-
-  private def castDecimalPrecision(
-      left: LogicalPlan,
-      right: LogicalPlan): (LogicalPlan, LogicalPlan) = {
-    val castedInput = left.output.zip(right.output).map {
-      case (lhs, rhs) if lhs.dataType != rhs.dataType =>
-        (lhs.dataType, rhs.dataType) match {
-          case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
-            // Decimals with precision/scale p1/s2 and p2/s2 will be promoted to
-            // DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
-            val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
-            (Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
-          case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
-            (Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
-          case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
-            (lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
-          case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
-            (Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
-          case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
-            (lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
-          case _ => (lhs, rhs)
-        }
-      case other => other
-    }
-
-    val (castedLeft, castedRight) = castedInput.unzip
+  // Returns the decimal type that is wider than both of the given decimal types
+  def widerDecimalType(d1: DecimalType, d2: DecimalType): DecimalType = {
+    widerDecimalType(d1.precision, d1.scale, d2.precision, d2.scale)
+  }
+
+  // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
+  def widerDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
+    val scale = max(s1, s2)
+    val range = max(p1 - s1, p2 - s2)
+    DecimalType.bounded(range + scale, scale)
+  }

-    val newLeft =
-      if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) {
-        Project(castedLeft, left)
-      } else {
-        left
-      }
+  /**
+   * An expression used to wrap the children when promoting the precision of DecimalType,
+   * to avoid promoting it multiple times.
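Reduced to plain arithmetic, widerDecimalType keeps the larger scale and the larger integral-digit range. A worked sketch; both example results agree with the expectations added to HiveTypeCoercionSuite later in this patch for Union/Except/Intersect widening:

    object WiderDecimalSketch {
      import scala.math.max
      final case class Dec(p: Int, s: Int)

      def wider(a: Dec, b: Dec): Dec = {
        val scale = max(a.s, b.s)
        Dec(max(a.p - a.s, b.p - b.s) + scale, scale)
      }

      def main(args: Array[String]): Unit = {
        // decimal(10,8) vs decimal(5,5): 2 integral and 8 fractional digits suffice.
        println(wider(Dec(10, 8), Dec(5, 5)))  // Dec(10,8)
        // An int column (as decimal(10,0)) vs decimal(10,5): 10 + 5 digits.
        println(wider(Dec(10, 0), Dec(10, 5))) // Dec(15,5)
      }
    }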
+   */
+  case class ChangePrecision(child: Expression) extends UnaryExpression {
+    override def dataType: DataType = child.dataType
+    override def eval(input: InternalRow): Any = child.eval(input)
+    override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
+    override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
+    override def prettyName: String = "change_precision"
+  }

-    val newRight =
-      if (castedRight.map(_.dataType) != right.output.map(_.dataType)) {
-        Project(castedRight, right)
-      } else {
-        right
-      }
-    (newLeft, newRight)
+  def changePrecision(e: Expression, dataType: DataType): Expression = {
+    ChangePrecision(Cast(e, dataType))
   }

   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    // fix decimal precision for union, intersect and except
-    case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
-      val (newLeft, newRight) = castDecimalPrecision(left, right)
-      Union(newLeft, newRight)
-    case i @ Intersect(left, right) if i.childrenResolved && !i.resolved =>
-      val (newLeft, newRight) = castDecimalPrecision(left, right)
-      Intersect(newLeft, newRight)
-    case e @ Except(left, right) if e.childrenResolved && !e.resolved =>
-      val (newLeft, newRight) = castDecimalPrecision(left, right)
-      Except(newLeft, newRight)
-
     // fix decimal precision for expressions
     case q => q.transformExpressions {
       // Skip nodes whose children have not been resolved yet
       case e if !e.childrenResolved => e

+      // Skip nodes that are already promoted
+      case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
+
       case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        Cast(
-          Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
-          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-        )
+        val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+        Add(changePrecision(e1, dt), changePrecision(e2, dt))

       case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        Cast(
-          Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
-          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
-        )
+        val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+        Subtract(changePrecision(e1, dt), changePrecision(e2, dt))

       case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        Cast(
-          Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
-          DecimalType(p1 + p2 + 1, s1 + s2)
-        )
+        val dt = DecimalType.bounded(p1 + p2 + 1, s1 + s2)
+        Multiply(changePrecision(e1, dt), changePrecision(e2, dt))

       case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        Cast(
-          Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
-          DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
-        )
+        val dt = DecimalType.bounded(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+        Divide(changePrecision(e1, dt), changePrecision(e2, dt))

       case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        Cast(
-          Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
-          DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-        )
+        val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+        // resultType may have lower precision, so we cast the operands to a wider type first.
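Worked through by hand, the bounded result types chosen by the arithmetic cases above come out as follows (a standalone sketch of the Hive-style formulas; bounded simply caps precision and scale at 38):

    object DecimalArithmeticSketch {
      import scala.math.{max, min}
      final case class Dec(p: Int, s: Int)
      private def bounded(p: Int, s: Int) = Dec(min(p, 38), min(s, 38))

      def add(a: Dec, b: Dec): Dec = // the Subtract rule is identical
        bounded(max(a.s, b.s) + max(a.p - a.s, b.p - b.s) + 1, max(a.s, b.s))
      def multiply(a: Dec, b: Dec): Dec =
        bounded(a.p + b.p + 1, a.s + b.s)
      def divide(a: Dec, b: Dec): Dec =
        bounded(a.p - a.s + b.s + max(6, a.s + b.p + 1), max(6, a.s + b.p + 1))

      def main(args: Array[String]): Unit = {
        val (d1, d2) = (Dec(10, 2), Dec(5, 4))
        println(add(d1, d2))      // Dec(13,4): 8 integral digits, 4 fractional, 1 carry
        println(multiply(d1, d2)) // Dec(16,6)
        println(divide(d1, d2))   // Dec(20,8)
      }
    }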
+        val widerType = widerDecimalType(p1, s1, p2, s2)
+        Cast(Remainder(changePrecision(e1, widerType), changePrecision(e2, widerType)),
+          resultType)

       case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
-        Cast(
-          Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
-          DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
-        )
+        val resultType = DecimalType.bounded(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+        // resultType may have lower precision, so we cast the operands to a wider type first.
+        val widerType = widerDecimalType(p1, s1, p2, s2)
+        Cast(Pmod(changePrecision(e1, widerType), changePrecision(e2, widerType)), resultType)

-      // When we compare 2 decimal types with different precisions, cast them to the smallest
-      // common precision.
       case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
           e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-        val resultType = DecimalType(max(p1, p2), max(s1, s2))
+        val resultType = widerDecimalType(p1, s1, p2, s2)
         b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))

       // Promote integers inside a binary expression with fixed-precision decimals to decimals,
       // and fixed-precision decimals in an expression with floats / doubles to doubles
       case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
         (left.dataType, right.dataType) match {
-          case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
-            b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
-          case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
-            b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
+          case (t: IntegralType, DecimalType.Fixed(p, s)) =>
+            b.makeCopy(Array(Cast(left, DecimalType.forType(t)), right))
+          case (DecimalType.Fixed(p, s), t: IntegralType) =>
+            b.makeCopy(Array(left, Cast(right, DecimalType.forType(t))))
           case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
             b.makeCopy(Array(left, Cast(right, DoubleType)))
           case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
@@ -485,7 +435,6 @@ object HiveTypeCoercion {
       // SUM and AVERAGE are handled by the implementations of those expressions
     }
   }
-
 }

 /**
@@ -563,7 +512,7 @@ object HiveTypeCoercion {
       case e if !e.childrenResolved => e

       case Cast(e @ StringType(), t: IntegralType) =>
-        Cast(Cast(e, DecimalType.Unlimited), t)
+        Cast(Cast(e, DecimalType.forType(LongType)), t)
     }
   }

@@ -756,8 +705,8 @@ object HiveTypeCoercion {
       // Implicit cast among numeric types. When we reach here, input type is not acceptable.

       // If input is a numeric type but not decimal, and we expect a decimal type,
-      // cast the input to unlimited precision decimal.
-      case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
+      // cast the input to decimal.
+      case (d: NumericType, DecimalType) => Cast(e, DecimalType.forType(d))

       // For any other numeric types, implicitly cast to each other, e.g.
long -> int, int -> long case (_: NumericType, target: NumericType) => Cast(e, target) @@ -766,7 +715,7 @@ object HiveTypeCoercion { case (TimestampType, DateType) => Cast(e, DateType) // Implicit cast from/to string - case (StringType, DecimalType) => Cast(e, DecimalType.Unlimited) + case (StringType, DecimalType) => Cast(e, DecimalType.SYSTEM_DEFAULT) case (StringType, target: NumericType) => Cast(e, target) case (StringType, DateType) => Cast(e, DateType) case (StringType, TimestampType) => Cast(e, TimestampType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 51821757967d2..a7e3a49327655 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -201,7 +201,7 @@ package object dsl { /** Creates a new AttributeReference of type decimal */ def decimal: AttributeReference = - AttributeReference(s, DecimalType.Unlimited, nullable = true)() + AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)() /** Creates a new AttributeReference of type decimal */ def decimal(precision: Int, scale: Int): AttributeReference = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e66cd828481bf..c66854d52c50b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -300,12 +300,7 @@ case class Cast(child: Expression, dataType: DataType) * NOTE: this modifies `value` in-place, so don't call it on external data. 
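Back in DecimalPrecision, the comparison rule composes with the integral promotion rule above: an int operand is first promoted via DecimalType.forType to decimal(10, 0), and then both sides are cast to the wider decimal type. A worked check that reproduces the int-versus-decimal(2,1) expectation in the updated DecimalPrecisionSuite:

    object ComparisonWideningSketch {
      import scala.math.max
      final case class Dec(p: Int, s: Int)

      def wider(a: Dec, b: Dec): Dec = {
        val s = max(a.s, b.s)
        Dec(max(a.p - a.s, b.p - b.s) + s, s)
      }

      def main(args: Array[String]): Unit = {
        val intAsDec = Dec(10, 0) // DecimalType.forType(IntegerType)
        println(wider(intAsDec, Dec(2, 1))) // Dec(11,1)
      }
    }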
*/ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { - decimalType match { - case DecimalType.Unlimited => - value - case DecimalType.Fixed(precision, scale) => - if (value.changePrecision(precision, scale)) value else null - } + if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null } private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index b924af4cc84d8..88fb516e64aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -36,14 +36,13 @@ case class Average(child: Expression) extends AlgebraicAggregate { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) private val resultType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } private val sumDataType = child.dataType match { - case _ @ DecimalType() => DecimalType.Unlimited + case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } @@ -71,7 +70,14 @@ case class Average(child: Expression) extends AlgebraicAggregate { ) // If all input are nulls, currentCount will be 0 and we will get null after the division. - override val evaluateExpression = Cast(currentSum, resultType) / Cast(currentCount, resultType) + override val evaluateExpression = child.dataType match { + case DecimalType.Fixed(p, s) => + // increase the precision and scale to prevent precision loss + val dt = DecimalType.bounded(p + 14, s + 4) + Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) + case _ => + Cast(currentSum, resultType) / Cast(currentCount, resultType) + } } case class Count(child: Expression) extends AlgebraicAggregate { @@ -255,15 +261,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate { private val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => DecimalType.Unlimited + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } - private val sumDataType = child.dataType match { - case _ @ DecimalType() => DecimalType.Unlimited - case _ => child.dataType - } + private val sumDataType = resultType private val currentSum = AttributeReference("currentSum", sumDataType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index d3295b8bafa80..73fde4e9164d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -390,22 +390,21 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg override def dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive - case 
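Tracing the new Average types for an input column of decimal(10, 2) makes the three precisions involved concrete (a worked sketch of the rules above, not the aggregate code itself):

    object AverageTypeSketch {
      final case class Dec(p: Int, s: Int)
      private def bounded(p: Int, s: Int) = Dec(math.min(p, 38), math.min(s, 38))

      def sumType(in: Dec): Dec    = bounded(in.p + 10, in.s)     // room for ~1e10 rows
      def resultType(in: Dec): Dec = bounded(in.p + 4, in.s + 4)  // 4 extra fraction digits, like Hive
      def divType(in: Dec): Dec    = bounded(in.p + 14, in.s + 4) // widened before dividing

      def main(args: Array[String]): Unit = {
        val in = Dec(10, 2)
        // Sum and count are cast to Dec(24,6), divided, and the quotient is
        // cast down to the declared result type Dec(14,6).
        println((sumType(in), divType(in), resultType(in))) // (Dec(20,2),Dec(24,6),Dec(14,6))
      }
    }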
DecimalType.Unlimited =>
-      DecimalType.Unlimited
+      // Add 4 digits after decimal point, like Hive
+      DecimalType.bounded(precision + 4, scale + 4)
     case _ =>
       DoubleType
   }

   override def asPartial: SplitEvaluation = {
     child.dataType match {
-      case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
-        // Turn the child to unlimited decimals for calculation, before going back to fixed
-        val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+      case DecimalType.Fixed(precision, scale) =>
+        val partialSum = Alias(Sum(child), "PartialSum")()
         val partialCount = Alias(Count(child), "PartialCount")()

-        val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
-        val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
+        // partialSum already increases the precision by 10
+        val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType)
+        val castedCount = Sum(partialCount.toAttribute)
         SplitEvaluation(
           Cast(Divide(castedSum, castedCount), dataType),
           partialCount :: partialSum :: Nil)
@@ -435,8 +434,8 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)

   private val calcType =
     expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        DecimalType.Unlimited
+      case DecimalType.Fixed(precision, scale) =>
+        DecimalType.bounded(precision + 10, scale)
       case _ =>
         expr.dataType
     }
@@ -454,10 +453,9 @@ case class AverageFunction(expr: Expression, base: AggregateExpression1)
       null
     } else {
       expr.dataType match {
-        case DecimalType.Fixed(_, _) =>
-          Cast(Divide(
-            Cast(sum, DecimalType.Unlimited),
-            Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
+        case DecimalType.Fixed(precision, scale) =>
+          val dt = DecimalType.bounded(precision + 14, scale + 4)
+          Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null)
         case _ =>
           Divide(
             Cast(sum, dataType),
@@ -481,9 +479,8 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1

   override def dataType: DataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
-      DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
-    case DecimalType.Unlimited =>
-      DecimalType.Unlimited
+      // Add 10 digits left of decimal point, like Hive
+      DecimalType.bounded(precision + 10, scale)
     case _ =>
       child.dataType
   }
@@ -491,7 +488,7 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1
   override def asPartial: SplitEvaluation = {
     child.dataType match {
       case DecimalType.Fixed(_, _) =>
-        val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
+        val partialSum = Alias(Sum(child), "PartialSum")()
         SplitEvaluation(
           Cast(CombineSum(partialSum.toAttribute), dataType),
           partialSum :: Nil)
@@ -515,8 +512,8 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg

   private val calcType =
     expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        DecimalType.Unlimited
+      case DecimalType.Fixed(precision, scale) =>
+        DecimalType.bounded(precision + 10, scale)
       case _ =>
         expr.dataType
     }
@@ -572,8 +569,8 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression1)

   private val calcType =
     expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        DecimalType.Unlimited
+      case DecimalType.Fixed(precision, scale) =>
+        DecimalType.bounded(precision + 10, scale)
       case _ =>
         expr.dataType
     }
@@ -608,9 +605,8 @@ case class SumDistinct(child: Expression) extends UnaryExpression with PartialAg

   override def nullable: Boolean = true

   override def
dataType: DataType = child.dataType match { case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited + // Add 10 digits left of decimal point, like Hive + DecimalType.bounded(precision + 10, scale) case _ => child.dataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 05b5ad88fee8f..7c254a8750a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -88,6 +88,8 @@ abstract class BinaryArithmetic extends BinaryOperator { override def dataType: DataType = left.dataType + override lazy val resolved = childrenResolved && checkInputDataTypes().isSuccess + /** Name of the function for this expression on a [[Decimal]] type. */ def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") @@ -114,9 +116,6 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "+" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -146,9 +145,6 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "-" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -179,9 +175,6 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "*" override def decimalMethod: String = "$times" - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) @@ -195,9 +188,6 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def decimalMethod: String = "$div" override def nullable: Boolean = true - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot @@ -260,9 +250,6 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def decimalMethod: String = "remainder" override def nullable: Boolean = true - override lazy val resolved = - childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType) - private lazy val integral = dataType match { case i: IntegralType => i.integral.asInstanceOf[Integral[Any]] case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index f25ac32679587..85060b7893556 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -36,9 +36,9 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(UTF8String.fromString(s), StringType) case b: Boolean => Literal(b, BooleanType) - case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) - case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) - case d: Decimal => Literal(d, DecimalType.Unlimited) + case d: BigDecimal => Literal(Decimal(d), DecimalType(d.precision, d.scale)) + case d: java.math.BigDecimal => Literal(Decimal(d), DecimalType(d.precision(), d.scale())) + case d: Decimal => Literal(d, DecimalType(d.precision, d.scale)) case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType) case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType) case a: Array[Byte] => Literal(a, BinaryType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d06a7a2add754..c610f70d38437 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans -import org.apache.spark.sql.catalyst.expressions.{VirtualColumn, Attribute, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, VirtualColumn} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] { self: PlanType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index e98fd2583b931..591fb26e67c4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -106,7 +106,7 @@ object DataType { private def nameToType(name: String): DataType = { val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r name match { - case "decimal" => DecimalType.Unlimited + case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) case other => nonDecimalNameToType(other) } @@ -177,7 +177,7 @@ object DataType { | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.Unlimited + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT | fixedDecimalType | "TimestampType" ^^^ TimestampType ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 6b43224feb1f2..6e081ea9237bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -48,7 +48,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)binary".r ^^^ 
BinaryType |
     "(?i)boolean".r ^^^ BooleanType |
     fixedDecimalType |
-    "(?i)decimal".r ^^^ DecimalType.Unlimited |
+    "(?i)decimal".r ^^^ DecimalType.USER_DEFAULT |
     "(?i)date".r ^^^ DateType |
     "(?i)timestamp".r ^^^ TimestampType |
     varchar
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 377c75f6e85a5..26b24616d98ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -26,25 +26,46 @@ import org.apache.spark.sql.catalyst.expressions.Expression


 /** Precision parameters for a Decimal */
+@deprecated("Use DecimalType(precision, scale) directly", "1.5")
 case class PrecisionInfo(precision: Int, scale: Int) {
   if (scale > precision) {
     throw new AnalysisException(
       s"Decimal scale ($scale) cannot be greater than precision ($precision).")
   }
+  if (precision > DecimalType.MAX_PRECISION) {
+    throw new AnalysisException(
+      s"DecimalType can only support precision up to 38"
+    )
+  }
 }

 /**
  * :: DeveloperApi ::
  * The data type representing `java.math.BigDecimal` values.
- * A Decimal that might have fixed precision and scale, or unlimited values for these.
+ * A Decimal that must have fixed precision (the maximum number of digits) and scale (the number
+ * of digits to the right of the decimal point).
+ *
+ * The precision can be up to 38; the scale can also be up to 38, but must be less than or equal
+ * to the precision.
+ *
+ * The default precision and scale is (10, 0).
  *
  * Please use [[DataTypes.createDecimalType()]] to create a specific instance.
  */
 @DeveloperApi
-case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+case class DecimalType(precision: Int, scale: Int) extends FractionalType {
+
+  // default constructor for Java
+  def this(precision: Int) = this(precision, 0)
+  def this() = this(10)
+
+  @deprecated("Use DecimalType(precision, scale) instead", "1.5")
+  def this(precisionInfo: Option[PrecisionInfo]) {
+    this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
+      precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
+  }

-  /** No-arg constructor for kryo.
*/
-  protected def this() = this(null)
+  @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5")
+  val precisionInfo = Some(PrecisionInfo(precision, scale))

   private[sql] type InternalType = Decimal
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
@@ -53,18 +74,16 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
   private[sql] val ordering = Decimal.DecimalIsFractional
   private[sql] val asIntegral = Decimal.DecimalAsIfIntegral

-  def precision: Int = precisionInfo.map(_.precision).getOrElse(-1)
-
-  def scale: Int = precisionInfo.map(_.scale).getOrElse(-1)
+  override def typeName: String = s"decimal($precision,$scale)"

-  override def typeName: String = precisionInfo match {
-    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
-    case None => "decimal"
-  }
+  override def toString: String = s"DecimalType($precision,$scale)"

-  override def toString: String = precisionInfo match {
-    case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
-    case None => "DecimalType()"
+  private[sql] def isWiderThan(other: DataType): Boolean = other match {
+    case dt: DecimalType =>
+      (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
+    case dt: IntegralType =>
+      isWiderThan(DecimalType.forType(dt))
+    case _ => false
   }

   /**
@@ -72,10 +91,7 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
    */
   override def defaultSize: Int = 4096

-  override def simpleString: String = precisionInfo match {
-    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
-    case None => "decimal(10,0)"
-  }
+  override def simpleString: String = s"decimal($precision,$scale)"

   private[spark] override def asNullable: DecimalType = this
 }

@@ -83,8 +99,47 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
 /** Extra factory methods and pattern matchers for Decimals */
 object DecimalType extends AbstractDataType {
+  import scala.math.min
+
+  val MAX_PRECISION = 38
+  val MAX_SCALE = 38
+  val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
+  val USER_DEFAULT: DecimalType = DecimalType(10, 0)
+
+  @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")
+  val Unlimited: DecimalType = SYSTEM_DEFAULT
+
+  // The decimal types compatible with other numeric types
+  private[sql] val ByteDecimal = DecimalType(3, 0)
+  private[sql] val ShortDecimal = DecimalType(5, 0)
+  private[sql] val IntDecimal = DecimalType(10, 0)
+  private[sql] val LongDecimal = DecimalType(20, 0)
+  private[sql] val FloatDecimal = DecimalType(14, 7)
+  private[sql] val DoubleDecimal = DecimalType(30, 15)
+
+  private[sql] def forType(dataType: DataType): DecimalType = dataType match {
+    case ByteType => ByteDecimal
+    case ShortType => ShortDecimal
+    case IntegerType => IntDecimal
+    case LongType => LongDecimal
+    case FloatType => FloatDecimal
+    case DoubleType => DoubleDecimal
+  }

-  override private[sql] def defaultConcreteType: DataType = Unlimited
+  @deprecated("please specify precision and scale", "1.5")
+  def apply(): DecimalType = USER_DEFAULT
+
+  @deprecated("Use DecimalType(precision, scale) instead", "1.5")
+  def apply(precisionInfo: Option[PrecisionInfo]): DecimalType = {
+    DecimalType(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision,
+      precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale)
+  }
+
+  private[sql] def bounded(precision: Int, scale: Int): DecimalType = {
+    DecimalType(min(precision,
MAX_PRECISION), min(scale, MAX_SCALE)) + } + + override private[sql] def defaultConcreteType: DataType = SYSTEM_DEFAULT override private[sql] def acceptsType(other: DataType): Boolean = { other.isInstanceOf[DecimalType] @@ -92,31 +147,18 @@ object DecimalType extends AbstractDataType { override private[sql] def simpleString: String = "decimal" - val Unlimited: DecimalType = DecimalType(None) - private[sql] object Fixed { - def unapply(t: DecimalType): Option[(Int, Int)] = - t.precisionInfo.map(p => (p.precision, p.scale)) + def unapply(t: DecimalType): Option[(Int, Int)] = Some((t.precision, t.scale)) } private[sql] object Expression { def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { - case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case t: DecimalType => Some((t.precision, t.scale)) case _ => None } } - def apply(): DecimalType = Unlimited - - def apply(precision: Int, scale: Int): DecimalType = - DecimalType(Some(PrecisionInfo(precision, scale))) - def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] - - def isFixed(dataType: DataType): Boolean = dataType match { - case DecimalType.Fixed(_, _) => true - case _ => false - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 13aad467fa578..b9f2ad7ec0481 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -94,8 +94,8 @@ object RandomDataGenerator { case BooleanType => Some(() => rand.nextBoolean()) case DateType => Some(() => new java.sql.Date(rand.nextInt())) case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) - case DecimalType.Unlimited => Some( - () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED)) + case DecimalType.Fixed(precision, scale) => Some( + () => BigDecimal.apply(rand.nextLong, rand.nextInt, new MathContext(precision))) case DoubleType => randomNumeric[Double]( rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index dbba93dba668e..677ba0a18040c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -50,9 +50,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { for ( dataType <- DataTypeTestUtils.atomicTypes; nullable <- Seq(true, false) - if !dataType.isInstanceOf[DecimalType] || - dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty - ) { + if !dataType.isInstanceOf[DecimalType]) { test(s"$dataType (nullable=$nullable)") { testRandomDataGeneration(dataType) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index b4b00f558463f..3b848cfdf737f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -102,7 +102,7 @@ class 
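bounded is what keeps the arithmetic formulas legal once operands approach the 38-digit ceiling. A worked sketch using the Divide formula from DecimalPrecision above; the example reproduces the Divide(u, d1) expectation of decimal(38, 21) in the updated DecimalPrecisionSuite, where u is decimal(38, 18) and d1 is decimal(2, 1):

    object BoundedCapSketch {
      import scala.math.{max, min}
      final case class Dec(p: Int, s: Int)
      private def bounded(p: Int, s: Int) = Dec(min(p, 38), min(s, 38))

      def divide(a: Dec, b: Dec): Dec = {
        val scale = max(6, a.s + b.p + 1)
        bounded(a.p - a.s + b.s + scale, scale)
      }

      def main(args: Array[String]): Unit = {
        // The raw formula asks for 42 digits; bounded caps it at 38, keeping scale 21.
        println(divide(Dec(38, 18), Dec(2, 1))) // Dec(38,21)
      }
    }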
ScalaReflectionSuite extends SparkFunSuite { StructField("byteField", ByteType, nullable = true), StructField("booleanField", BooleanType, nullable = true), StructField("stringField", StringType, nullable = true), - StructField("decimalField", DecimalType.Unlimited, nullable = true), + StructField("decimalField", DecimalType.SYSTEM_DEFAULT, nullable = true), StructField("dateField", DateType, nullable = true), StructField("timestampField", TimestampType, nullable = true), StructField("binaryField", BinaryType, nullable = true))), @@ -216,7 +216,7 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(DoubleType === typeOfObject(1.7976931348623157E308)) // DecimalType - assert(DecimalType.Unlimited === + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject(new java.math.BigDecimal("1.7976931348623157E318"))) // DateType @@ -229,19 +229,19 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(NullType === typeOfObject(null)) def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.Unlimited - case value: java.math.BigDecimal => DecimalType.Unlimited + case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT + case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT case _ => StringType } - assert(DecimalType.Unlimited === typeOfObject1( + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( new BigInteger("92233720368547758070"))) - assert(DecimalType.Unlimited === typeOfObject1( + assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1( new java.math.BigDecimal("1.7976931348623157E318"))) assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType.Unlimited + case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT } intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 58df1de983a09..7e67427237a65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -55,7 +55,7 @@ object AnalysisSuite { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.Unlimited)(), + AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("e", ShortType)()) val nestedRelation = LocalRelation( @@ -158,7 +158,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.Unlimited)(), + AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) val plan = caseInsensitiveAnalyzer.execute( @@ -173,7 +173,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType.Unlimited) + assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double assert(pl(4).dataType == DoubleType) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 7bac97b7894f5..f9f15e7a6608d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -34,7 +34,7 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { AttributeReference("i", IntegerType)(), AttributeReference("d1", DecimalType(2, 1))(), AttributeReference("d2", DecimalType(5, 2))(), - AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("f", FloatType)(), AttributeReference("b", DoubleType)() ) @@ -92,11 +92,11 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { } test("Comparison operations") { - checkComparison(EqualTo(i, d1), DecimalType(10, 1)) + checkComparison(EqualTo(i, d1), DecimalType(11, 1)) checkComparison(EqualNullSafe(d2, d1), DecimalType(5, 2)) - checkComparison(LessThan(i, d1), DecimalType(10, 1)) + checkComparison(LessThan(i, d1), DecimalType(11, 1)) checkComparison(LessThanOrEqual(d1, d2), DecimalType(5, 2)) - checkComparison(GreaterThan(d2, u), DecimalType.Unlimited) + checkComparison(GreaterThan(d2, u), DecimalType.SYSTEM_DEFAULT) checkComparison(GreaterThanOrEqual(d1, f), DoubleType) checkComparison(GreaterThan(d2, d2), DecimalType(5, 2)) } @@ -106,12 +106,12 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkUnion(i, d2, DecimalType(12, 2)) checkUnion(d1, d2, DecimalType(5, 2)) checkUnion(d2, d1, DecimalType(5, 2)) - checkUnion(d1, f, DecimalType(8, 7)) - checkUnion(f, d2, DecimalType(10, 7)) - checkUnion(d1, b, DecimalType(16, 15)) - checkUnion(b, d2, DecimalType(18, 15)) - checkUnion(d1, u, DecimalType.Unlimited) - checkUnion(u, d2, DecimalType.Unlimited) + checkUnion(d1, f, DoubleType) + checkUnion(f, d2, DoubleType) + checkUnion(d1, b, DoubleType) + checkUnion(b, d2, DoubleType) + checkUnion(d1, u, DecimalType.SYSTEM_DEFAULT) + checkUnion(u, d2, DecimalType.SYSTEM_DEFAULT) } test("bringing in primitive types") { @@ -125,13 +125,33 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Add(d1, Cast(i, DoubleType)), DoubleType) } - test("unlimited decimals make everything else cast up") { - for (expr <- Seq(d1, d2, i, f, u)) { - checkType(Add(expr, u), DecimalType.Unlimited) - checkType(Subtract(expr, u), DecimalType.Unlimited) - checkType(Multiply(expr, u), DecimalType.Unlimited) - checkType(Divide(expr, u), DecimalType.Unlimited) - checkType(Remainder(expr, u), DecimalType.Unlimited) + test("maximum decimals") { + for (expr <- Seq(d1, d2, i, u)) { + checkType(Add(expr, u), DecimalType.SYSTEM_DEFAULT) + checkType(Subtract(expr, u), DecimalType.SYSTEM_DEFAULT) + } + + checkType(Multiply(d1, u), DecimalType(38, 19)) + checkType(Multiply(d2, u), DecimalType(38, 20)) + checkType(Multiply(i, u), DecimalType(38, 18)) + checkType(Multiply(u, u), DecimalType(38, 36)) + + checkType(Divide(u, d1), DecimalType(38, 21)) + checkType(Divide(u, d2), DecimalType(38, 24)) + checkType(Divide(u, i), DecimalType(38, 29)) + checkType(Divide(u, u), DecimalType(38, 38)) + + checkType(Remainder(d1, u), DecimalType(19, 18)) + checkType(Remainder(d2, u), DecimalType(21, 18)) + checkType(Remainder(i, u), DecimalType(28, 18)) + checkType(Remainder(u, u), 
DecimalType.SYSTEM_DEFAULT) + + for (expr <- Seq(f, b)) { + checkType(Add(expr, u), DoubleType) + checkType(Subtract(expr, u), DoubleType) + checkType(Multiply(expr, u), DoubleType) + checkType(Divide(expr, u), DoubleType) + checkType(Remainder(expr, u), DoubleType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 835220c563f41..d0fb95b580ad2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -35,14 +35,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(NullType, NullType, NullType) shouldCast(NullType, IntegerType, IntegerType) - shouldCast(NullType, DecimalType, DecimalType.Unlimited) + shouldCast(NullType, DecimalType, DecimalType.SYSTEM_DEFAULT) shouldCast(ByteType, IntegerType, IntegerType) shouldCast(IntegerType, IntegerType, IntegerType) shouldCast(IntegerType, LongType, LongType) - shouldCast(IntegerType, DecimalType, DecimalType.Unlimited) + shouldCast(IntegerType, DecimalType, DecimalType(10, 0)) shouldCast(LongType, IntegerType, IntegerType) - shouldCast(LongType, DecimalType, DecimalType.Unlimited) + shouldCast(LongType, DecimalType, DecimalType(20, 0)) shouldCast(DateType, TimestampType, TimestampType) shouldCast(TimestampType, DateType, DateType) @@ -71,8 +71,8 @@ class HiveTypeCoercionSuite extends PlanTest { shouldCast(IntegerType, TypeCollection(StringType, BinaryType), StringType) shouldCast(IntegerType, TypeCollection(BinaryType, StringType), StringType) - shouldCast( - DecimalType.Unlimited, TypeCollection(IntegerType, DecimalType), DecimalType.Unlimited) + shouldCast(DecimalType.SYSTEM_DEFAULT, + TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) @@ -82,7 +82,7 @@ class HiveTypeCoercionSuite extends PlanTest { // NumericType should not be changed when function accepts any of them. 
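The shouldCast expectations above follow directly from the new DecimalType.forType mapping, and the system default is wide enough to absorb any of the integral-backed decimals; a standalone illustration (the Map mirrors forType and is not the Spark code itself):

    object ForTypeSketch {
      final case class Dec(p: Int, s: Int) {
        // Mirrors DecimalType.isWiderThan: integral digits and scale both dominate.
        def isWiderThan(o: Dec): Boolean = (p - s) >= (o.p - o.s) && s >= o.s
      }

      val forType = Map(
        "byte" -> Dec(3, 0), "short" -> Dec(5, 0),
        "int" -> Dec(10, 0), "long" -> Dec(20, 0))

      def main(args: Array[String]): Unit = {
        println(forType("int"))  // Dec(10,0): shouldCast(IntegerType, DecimalType, DecimalType(10, 0))
        println(forType("long")) // Dec(20,0): shouldCast(LongType, DecimalType, DecimalType(20, 0))
        println(forType.values.forall(Dec(38, 18).isWiderThan)) // true
      }
    }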
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, - DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe => + DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)).foreach { tpe => shouldCast(tpe, NumericType, tpe) } @@ -107,8 +107,8 @@ class HiveTypeCoercionSuite extends PlanTest { shouldNotCast(IntegerType, TimestampType) shouldNotCast(LongType, DateType) shouldNotCast(LongType, TimestampType) - shouldNotCast(DecimalType.Unlimited, DateType) - shouldNotCast(DecimalType.Unlimited, TimestampType) + shouldNotCast(DecimalType.SYSTEM_DEFAULT, DateType) + shouldNotCast(DecimalType.SYSTEM_DEFAULT, TimestampType) shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) @@ -160,14 +160,6 @@ class HiveTypeCoercionSuite extends PlanTest { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) - // Casting up to unlimited-precision decimal - widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) - widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) - // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) widenTest(DecimalType(2, 1), DecimalType(3, 2), None) widenTest(DecimalType(2, 1), DoubleType, None) @@ -242,9 +234,9 @@ class HiveTypeCoercionSuite extends PlanTest { :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) :: Nil), - Coalesce(Cast(Literal(1L), DecimalType()) - :: Cast(Literal(1), DecimalType()) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) + Coalesce(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) } @@ -323,7 +315,7 @@ class HiveTypeCoercionSuite extends PlanTest { val left = LocalRelation( AttributeReference("i", IntegerType)(), - AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("u", DecimalType.SYSTEM_DEFAULT)(), AttributeReference("b", ByteType)(), AttributeReference("d", DoubleType)()) val right = LocalRelation( @@ -333,7 +325,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("l", LongType)()) val wt = HiveTypeCoercion.WidenTypes - val expectedTypes = Seq(StringType, DecimalType.Unlimited, FloatType, DoubleType) + val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val r1 = wt(Union(left, right)).asInstanceOf[Union] val r2 = wt(Except(left, right)).asInstanceOf[Except] @@ -353,13 +345,13 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val dp = HiveTypeCoercion.DecimalPrecision + val dp = HiveTypeCoercion.WidenTypes val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) - val expectedType1 = Seq(DecimalType(math.max(8, 5) + math.max(10 - 8, 5 - 5), math.max(8, 5))) + val expectedType1 = Seq(DecimalType(10, 8)) val r1 = dp(Union(left1, right1)).asInstanceOf[Union] val r2 = dp(Except(left1, right1)).asInstanceOf[Except] @@ -372,12 +364,11 @@ class HiveTypeCoercionSuite extends PlanTest { checkOutput(r3.left, expectedType1) checkOutput(r3.right, expectedType1) - 
val plan1 = LocalRelation( - AttributeReference("l", DecimalType(10, 10))()) + val plan1 = LocalRelation(AttributeReference("l", DecimalType(10, 5))()) val rightTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) - val expectedTypes = Seq(DecimalType(3, 0), DecimalType(5, 0), DecimalType(10, 0), - DecimalType(20, 0), DecimalType(7, 7), DecimalType(15, 15)) + val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5), + DecimalType(25, 5), DoubleType, DoubleType) rightTypes.zip(expectedTypes).map { case (rType, expectedType) => val plan2 = LocalRelation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index ccf448eee0688..facf65c155148 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -185,7 +185,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1, 1.0) checkCast(123, "123") - checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) @@ -203,7 +203,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(1L, 1.0) checkCast(123L, "123") - checkEvaluation(cast(123L, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) @@ -225,7 +225,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(cast(1000, TimestampType), LongType), 1.toLong) checkEvaluation(cast(cast(-1200, TimestampType), LongType), -2.toLong) - checkEvaluation(cast(123, DecimalType.Unlimited), Decimal(123)) + checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123)) checkEvaluation(cast(123, DecimalType(3, 1)), null) checkEvaluation(cast(123, DecimalType(2, 0)), null) @@ -267,7 +267,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast("abcdef", IntegerType).nullable === true) assert(cast("abcdef", ShortType).nullable === true) assert(cast("abcdef", ByteType).nullable === true) - assert(cast("abcdef", DecimalType.Unlimited).nullable === true) + assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true) assert(cast("abcdef", DecimalType(4, 2)).nullable === true) assert(cast("abcdef", DoubleType).nullable === true) assert(cast("abcdef", FloatType).nullable === true) @@ -291,9 +291,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { c.getTimeInMillis * 1000) checkEvaluation(cast("abdef", StringType), "abdef") - checkEvaluation(cast("abdef", DecimalType.Unlimited), null) + checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) checkEvaluation(cast("abdef", TimestampType), null) - checkEvaluation(cast("12.65", DecimalType.Unlimited), Decimal(12.65)) + checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65)) checkEvaluation(cast(cast(sd, DateType), StringType), sd) checkEvaluation(cast(cast(d, StringType), DateType), 0) @@ -311,20 +311,20 @@ class CastSuite 
extends SparkFunSuite with ExpressionEvalHelper { 5.toLong) checkEvaluation( cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType), - DecimalType.Unlimited), LongType), StringType), ShortType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), 0.toShort) checkEvaluation( cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType), - DecimalType.Unlimited), LongType), StringType), ShortType), + DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType), null) - checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.Unlimited), + checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT), ByteType), TimestampType), LongType), StringType), ShortType), 0.toShort) checkEvaluation(cast("23", DoubleType), 23d) checkEvaluation(cast("23", IntegerType), 23) checkEvaluation(cast("23", FloatType), 23f) - checkEvaluation(cast("23", DecimalType.Unlimited), Decimal(23)) + checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23)) checkEvaluation(cast("23", ByteType), 23.toByte) checkEvaluation(cast("23", ShortType), 23.toShort) checkEvaluation(cast("2012-12-11", DoubleType), null) @@ -338,7 +338,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d) checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24) checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f) - checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.Unlimited)), Decimal(24)) + checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24)) checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte) checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort) } @@ -362,10 +362,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { // - Values that would overflow the target precision should turn into null // - Because of this, casts to fixed-precision decimals should be nullable - assert(cast(123, DecimalType.Unlimited).nullable === false) - assert(cast(10.03f, DecimalType.Unlimited).nullable === true) - assert(cast(10.03, DecimalType.Unlimited).nullable === true) - assert(cast(Decimal(10.03), DecimalType.Unlimited).nullable === false) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) @@ -373,7 +373,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) - checkEvaluation(cast(10.03, DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10)) @@ -383,7 +383,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0)) checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null) - checkEvaluation(cast(10.05, DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(4, 
2)), Decimal(10.05)) checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1)) checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10)) @@ -409,10 +409,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) - checkEvaluation(cast(Double.NaN, DecimalType.Unlimited), null) - checkEvaluation(cast(1.0 / 0.0, DecimalType.Unlimited), null) - checkEvaluation(cast(Float.NaN, DecimalType.Unlimited), null) - checkEvaluation(cast(1.0f / 0.0f, DecimalType.Unlimited), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) + checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) @@ -427,7 +427,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(d, LongType), null) checkEvaluation(cast(d, FloatType), null) checkEvaluation(cast(d, DoubleType), null) - checkEvaluation(cast(d, DecimalType.Unlimited), null) + checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(d, DecimalType(10, 2)), null) checkEvaluation(cast(d, StringType), "1970-01-01") checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00") @@ -454,7 +454,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast(cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) checkEvaluation( - cast(cast(Decimal(1), TimestampType), DecimalType.Unlimited), + cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT), Decimal(1)) // A test for higher precision than millis diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index afa143bd5f331..b31d6661c8c1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -60,7 +60,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper testIf(_.toFloat, FloatType) testIf(_.toDouble, DoubleType) - testIf(Decimal(_), DecimalType.Unlimited) + testIf(Decimal(_), DecimalType.USER_DEFAULT) testIf(identity, DateType) testIf(_.toLong, TimestampType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index d924ff7a102f6..f6404d21611e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -33,7 +33,7 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, LongType), null) checkEvaluation(Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, BinaryType), null) - checkEvaluation(Literal.create(null, DecimalType()), null) + checkEvaluation(Literal.create(null, 
DecimalType.USER_DEFAULT), null) checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null) checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null) checkEvaluation(Literal.create(null, StructType(Seq.empty)), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 0728f6695c39d..9efe44c83293d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -30,7 +30,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testFunc(1L, LongType) testFunc(1.0F, FloatType) testFunc(1.0, DoubleType) - testFunc(Decimal(1.5), DecimalType.Unlimited) + testFunc(Decimal(1.5), DecimalType(2, 1)) testFunc(new java.sql.Date(10), DateType) testFunc(new java.sql.Timestamp(10), TimestampType) testFunc("abcd", StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7566cb59e34ee..48b7dc57451a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -121,6 +121,8 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 8819234e78e60..a5d9806c20463 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -145,7 +145,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { DoubleType, StringType, BinaryType - // DecimalType.Unlimited, + // DecimalType.Default, // ArrayType(IntegerType) ) val converter = new UnsafeRowConverter(fieldTypes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala index c6171b7b6916d..1ba290753ce48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeParserSuite.scala @@ -44,7 +44,7 @@ class DataTypeParserSuite extends SparkFunSuite { checkDataType("float", FloatType) checkDataType("dOUBle", DoubleType) checkDataType("decimal(10, 5)", DecimalType(10, 5)) - checkDataType("decimal", DecimalType.Unlimited) + checkDataType("decimal", DecimalType.USER_DEFAULT) checkDataType("DATE", DateType) checkDataType("timestamp", TimestampType) checkDataType("string", StringType) @@ -87,7 +87,7 @@ class DataTypeParserSuite extends SparkFunSuite { StructType( StructField("struct", StructType( - StructField("deciMal", DecimalType.Unlimited, true) :: + StructField("deciMal", DecimalType.USER_DEFAULT, true) :: StructField("anotherDecimal", 
DecimalType(5, 2), true) :: Nil), true) :: StructField("MAP", MapType(TimestampType, StringType), true) :: StructField("arrAy", ArrayType(DoubleType, true), true) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 14e7b4a9561b6..88b221cd81d74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -185,7 +185,7 @@ class DataTypeSuite extends SparkFunSuite { checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) checkDataTypeJsonRepr(DecimalType(10, 5)) - checkDataTypeJsonRepr(DecimalType.Unlimited) + checkDataTypeJsonRepr(DecimalType.SYSTEM_DEFAULT) checkDataTypeJsonRepr(DateType) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) @@ -219,7 +219,7 @@ class DataTypeSuite extends SparkFunSuite { checkDefaultSize(FloatType, 4) checkDefaultSize(DoubleType, 8) checkDefaultSize(DecimalType(10, 5), 4096) - checkDefaultSize(DecimalType.Unlimited, 4096) + checkDefaultSize(DecimalType.SYSTEM_DEFAULT, 4096) checkDefaultSize(DateType, 4) checkDefaultSize(TimestampType, 8) checkDefaultSize(StringType, 4096) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala index 32632b5d6e342..0ee9ddac815b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala @@ -34,7 +34,7 @@ object DataTypeTestUtils { * decimal types. */ val fractionalTypes: Set[FractionalType] = Set( - DecimalType(precisionInfo = None), + DecimalType.SYSTEM_DEFAULT, DecimalType(2, 1), DoubleType, FloatType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 6e2a6525bf17e..b25dcbca82b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -996,7 +996,7 @@ class ColumnName(name: String) extends Column(name) { * Creates a new [[StructField]] of type decimal. * @since 1.3.0 */ - def decimal: StructField = StructField(name, DecimalType.Unlimited) + def decimal: StructField = StructField(name, DecimalType.USER_DEFAULT) /** * Creates a new [[StructField]] of type decimal. 
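For reference while reading these hunks: `DecimalType.Unlimited` is replaced throughout this series by two named defaults plus a bounding helper, and the Union/Except expectations in the HiveTypeCoercionSuite hunks above all follow one widening rule. The sketch below is a summary inferred from the tests in this series (for example, a bare `decimal` parsing to a (10, 0) type and the (38, 18) columns expected from JSON and JDBC); it is an illustration, not the Catalyst definitions themselves:

    import org.apache.spark.sql.types._

    // Defaults assumed by the tests in this series.
    val MAX_PRECISION = 38
    val SYSTEM_DEFAULT = DecimalType(MAX_PRECISION, 18) // replaces most Unlimited uses
    val USER_DEFAULT = DecimalType(10, 0)               // bare "decimal" in DDL and HiveQL

    // Clamp a computed precision/scale instead of overflowing, as in the
    // DecimalType.bounded(...) calls that later hunks introduce.
    def bounded(precision: Int, scale: Int): DecimalType =
      DecimalType(math.min(precision, MAX_PRECISION), math.min(scale, MAX_PRECISION))

    // Widening two decimals keeps the larger scale and enough integral digits;
    // integral types act as decimals covering their range, e.g. Int -> (10, 0),
    // Long -> (20, 0), while Float/Double widen straight to DoubleType.
    def widenDecimals(d1: DecimalType, d2: DecimalType): DecimalType = {
      val scale = math.max(d1.scale, d2.scale)
      val integral = math.max(d1.precision - d1.scale, d2.precision - d2.scale)
      bounded(integral + scale, scale)
    }

    // Matches the HiveTypeCoercionSuite expectations above:
    //   widenDecimals(DecimalType(10, 8), DecimalType(5, 5))  == DecimalType(10, 8)
    //   widenDecimals(DecimalType(10, 5), DecimalType(10, 0)) == DecimalType(15, 5)  // Int
    //   widenDecimals(DecimalType(10, 5), DecimalType(20, 0)) == DecimalType(25, 5)  // Long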
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index fc72360c88fe1..9d8415f06399c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -375,7 +375,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) extends NativeColumnType( - DecimalType(Some(PrecisionInfo(precision, scale))), + DecimalType(precision, scale), 10, FIXED_DECIMAL.defaultSize) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 16176abe3a51d..5ed158b3d2912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -21,9 +21,9 @@ import org.apache.spark.TaskContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees._ import org.apache.spark.sql.types._ case class AggregateEvaluation( @@ -92,8 +92,8 @@ case class GeneratedAggregate( case s @ Sum(expr) => val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 10, s) case _ => expr.dataType } @@ -121,8 +121,8 @@ case class GeneratedAggregate( case cs @ CombineSum(expr) => val calcType = expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 10, s) case _ => expr.dataType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 6b4a359db22d1..9d0fa894b9942 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -25,6 +25,7 @@ import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.types._ @@ -236,7 +237,7 @@ private[sql] object PartitioningUtils { /** * Converts a string to a [[Literal]] with automatic type inference. Currently only supports - * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.Unlimited]], and + * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and * [[StringType]]. 
*/ private[sql] def inferPartitionColumnValue( @@ -249,7 +250,7 @@ private[sql] object PartitioningUtils { .orElse(Try(Literal.create(JLong.parseLong(raw), LongType))) // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) - .orElse(Try(Literal.create(new JBigDecimal(raw), DecimalType.Unlimited))) + .orElse(Try(Literal(new JBigDecimal(raw)))) // Then falls back to string .getOrElse { if (raw == defaultPartitionName) { @@ -268,7 +269,7 @@ private[sql] object PartitioningUtils { } private val upCastingOrder: Seq[DataType] = - Seq(NullType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited, StringType) + Seq(NullType, IntegerType, LongType, FloatType, DoubleType, StringType) /** * Given a collection of [[Literal]]s, resolves possible type conflicts by up-casting "lower" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 7a27fba1780b9..3cf70db6b7b09 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -66,8 +66,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.DATALINK => null case java.sql.Types.DATE => DateType case java.sql.Types.DECIMAL - if precision != 0 || scale != 0 => DecimalType(precision, scale) - case java.sql.Types.DECIMAL => DecimalType.Unlimited + if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) + case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT case java.sql.Types.DISTINCT => null case java.sql.Types.DOUBLE => DoubleType case java.sql.Types.FLOAT => FloatType @@ -80,8 +80,8 @@ private[sql] object JDBCRDD extends Logging { case java.sql.Types.NCLOB => StringType case java.sql.Types.NULL => null case java.sql.Types.NUMERIC - if precision != 0 || scale != 0 => DecimalType(precision, scale) - case java.sql.Types.NUMERIC => DecimalType.Unlimited + if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale) + case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT case java.sql.Types.NVARCHAR => StringType case java.sql.Types.OTHER => null case java.sql.Types.REAL => DoubleType @@ -314,7 +314,7 @@ private[sql] class JDBCRDD( abstract class JDBCConversion case object BooleanConversion extends JDBCConversion case object DateConversion extends JDBCConversion - case class DecimalConversion(precisionInfo: Option[(Int, Int)]) extends JDBCConversion + case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion case object DoubleConversion extends JDBCConversion case object FloatConversion extends JDBCConversion case object IntegerConversion extends JDBCConversion @@ -331,8 +331,7 @@ private[sql] class JDBCRDD( schema.fields.map(sf => sf.dataType match { case BooleanType => BooleanConversion case DateType => DateConversion - case DecimalType.Unlimited => DecimalConversion(None) - case DecimalType.Fixed(d) => DecimalConversion(Some(d)) + case DecimalType.Fixed(p, s) => DecimalConversion(p, s) case DoubleType => DoubleConversion case FloatType => FloatConversion case IntegerType => IntegerConversion @@ -399,20 +398,13 @@ private[sql] class JDBCRDD( // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. 
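The comment above explains why precision and scale must come from the JDBC metadata; condensed, the new mapping in `getCatalystType` and the per-row conversion that follows amount to this (a sketch using the `bounded` and `SYSTEM_DEFAULT` helpers this series introduces, not the exact JDBCRDD code):

    import org.apache.spark.sql.types._

    // DECIMAL/NUMERIC columns keep the driver-reported precision and scale,
    // clamped to the supported maximum; only drivers that report (0, 0)
    // fall back to the system-wide default.
    def decimalCatalystType(precision: Int, scale: Int): DataType =
      if (precision != 0 || scale != 0) DecimalType.bounded(precision, scale)
      else DecimalType.SYSTEM_DEFAULT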
-        case DecimalConversion(Some((p, s))) =>
+        case DecimalConversion(p, s) =>
           val decimalVal = rs.getBigDecimal(pos)
           if (decimalVal == null) {
             mutableRow.update(i, null)
           } else {
             mutableRow.update(i, Decimal(decimalVal, p, s))
           }
-        case DecimalConversion(None) =>
-          val decimalVal = rs.getBigDecimal(pos)
-          if (decimalVal == null) {
-            mutableRow.update(i, null)
-          } else {
-            mutableRow.update(i, Decimal(decimalVal))
-          }
         case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
         case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
         case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index f7ea852fe7f58..035e0510080ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -89,8 +89,7 @@ package object jdbc {
           case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
           case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
           case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
-          case DecimalType.Unlimited => stmt.setBigDecimal(i + 1,
-            row.getAs[java.math.BigDecimal](i))
+          case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
           case _ => throw new IllegalArgumentException(
             s"Can't translate non-null value for field $i")
         }
@@ -145,7 +144,7 @@ package object jdbc {
           case BinaryType => "BLOB"
           case TimestampType => "TIMESTAMP"
           case DateType => "DATE"
-          case DecimalType.Unlimited => "DECIMAL(40,20)"
+          case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
           case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
         })
       val nullable = if (field.nullable) "" else "NOT NULL"
@@ -177,7 +176,7 @@ package object jdbc {
           case BinaryType => java.sql.Types.BLOB
           case TimestampType => java.sql.Types.TIMESTAMP
           case DateType => java.sql.Types.DATE
-          case DecimalType.Unlimited => java.sql.Types.DECIMAL
+          case t: DecimalType => java.sql.Types.DECIMAL
           case _ => throw new IllegalArgumentException(
             s"Can't translate null value for field $field")
         })
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
index afe2c6c11ac69..0eb3b04007f8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -113,7 +113,7 @@ private[sql] object InferSchema {
       case INT | LONG => LongType
       // Since we do not have a data type backed by BigInteger,
       // when we see a Java BigInteger, we use DecimalType.
-      case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
+      case BIG_INTEGER | BIG_DECIMAL => DecimalType.SYSTEM_DEFAULT
       case FLOAT | DOUBLE => DoubleType
     }
@@ -168,8 +168,13 @@ private[sql] object InferSchema {
     HiveTypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
       (t1, t2) match {
-        case (other: DataType, NullType) => other
-        case (NullType, other: DataType) => other
+        // Double supports a larger range than fixed decimal; DecimalType.Maximum should be enough
+        // in most cases, and it also has better precision.
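The range argument in that comment can be sanity-checked outside Spark: a double spans roughly ±1.8e308 while even the widest 38-digit decimal stays below 1e38, so resolving a Double/Decimal conflict to DoubleType loses precision but never range (plain Scala, purely illustrative):

    // Any DecimalType(38, 0) value fits within a double's range,
    // though not within its 15-17 significant digits.
    val maxDecimal = BigDecimal("9" * 38)
    assert(maxDecimal.toDouble < Double.MaxValue)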
+ case (DoubleType, t: DecimalType) => + if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + case (t: DecimalType, DoubleType) => + if (t == DecimalType.SYSTEM_DEFAULT) t else DoubleType + case (StructType(fields1), StructType(fields2)) => val newFields = (fields1 ++ fields2).groupBy(field => field.name).map { case (name, fieldTypes) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1ea6926af6d5b..1d3a0d15d336e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -439,10 +439,6 @@ private[parquet] class CatalystSchemaConverter( .length(minBytesForPrecision(precision)) .named(field.name) - case dec @ DecimalType.Unlimited if followParquetFormatSpec => - throw new AnalysisException( - s"Data type $dec is not supported. Decimal precision and scale must be specified.") - // =================================================== // ArrayType and MapType (for Spark versions <= 1.4.x) // =================================================== diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e8851ddb68026..d1040bf5562a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -261,10 +261,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + if (d.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") } - writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) + writeDecimal(value.asInstanceOf[Decimal], d.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -415,10 +415,10 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) case d: DecimalType => - if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + if (d.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") } - writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) + writeDecimal(record(index).asInstanceOf[Decimal], d.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index fcb8f5499cf84..cb84e78d628ca 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -22,7 +22,6 @@ import java.util.Arrays; import java.util.List; -import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -31,8 +30,14 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import 
org.apache.spark.sql.*; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.test.TestSQLContext$; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -159,7 +164,8 @@ public void applySchemaToJSON() { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true)); + fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(38, 18), + true)); fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true)); fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true)); fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true)); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 01bc23277fa88..037e2048a8631 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -148,7 +148,7 @@ class InMemoryColumnarQuerySuite extends QueryTest { val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3d71deb13e884..845ce669f0b33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -109,7 +109,7 @@ class PlannerSuite extends SparkFunSuite { FloatType :: DoubleType :: DecimalType(10, 5) :: - DecimalType.Unlimited :: + DecimalType.SYSTEM_DEFAULT :: DateType :: TimestampType :: StringType :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala index 4a53fadd7e099..54f82f89ed18a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala @@ -54,7 +54,7 @@ class SparkSqlSerializer2DataTypeSuite extends SparkFunSuite { checkSupported(StringType, isSupported = true) checkSupported(BinaryType, isSupported = true) checkSupported(DecimalType(10, 5), isSupported = true) - checkSupported(DecimalType.Unlimited, isSupported = true) + checkSupported(DecimalType.SYSTEM_DEFAULT, isSupported = true) // If NullType is the only data type in the schema, we do not support it. 
checkSupported(NullType, isSupported = false) @@ -86,7 +86,7 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll val supportedTypes = Seq(StringType, BinaryType, NullType, BooleanType, ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5), + FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType(6, 5), DateType, TimestampType) val fields = supportedTypes.zipWithIndex.map { case (dataType, index) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 0f82f13088d39..42f2449afb0f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -134,7 +134,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))" + conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(38, 18))" ).executeUpdate() conn.prepareStatement("insert into test.flttypes values (" + "1.0000000000000002220446049250313080847263336181640625, " @@ -152,7 +152,7 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { s""" |create table test.nulltypes (a INT, b BOOLEAN, c TINYINT, d BINARY(20), e VARCHAR(20), |f VARCHAR_IGNORECASE(20), g CHAR(20), h BLOB, i CLOB, j TIME, k DATE, l TIMESTAMP, - |m DOUBLE, n REAL, o DECIMAL(40, 20)) + |m DOUBLE, n REAL, o DECIMAL(38, 18)) """.stripMargin.replaceAll("\n", " ")).executeUpdate() conn.prepareStatement("insert into test.nulltypes values (" + "null, null, null, null, null, null, null, null, null, " @@ -357,14 +357,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter { test("H2 floating-point types") { val rows = sql("SELECT * FROM flttypes").collect() - assert(rows(0).getDouble(0) === 1.00000000000000022) // Yes, I meant ==. - assert(rows(0).getDouble(1) === 1.00000011920928955) // Yes, I meant ==. 
- assert(rows(0).getAs[BigDecimal](2) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) - assert(rows(0).schema.fields(2).dataType === DecimalType(40, 20)) - val compareDecimal = sql("SELECT C FROM flttypes where C > C - 1").collect() - assert(compareDecimal(0).getAs[BigDecimal](0) - .equals(new BigDecimal("123456789012345.54321543215432100000"))) + assert(rows(0).getDouble(0) === 1.00000000000000022) + assert(rows(0).getDouble(1) === 1.00000011920928955) + assert(rows(0).getAs[BigDecimal](2) === + new BigDecimal("123456789012345.543215432154321000")) + assert(rows(0).schema.fields(2).dataType === DecimalType(38, 18)) + val result = sql("SELECT C FROM flttypes where C > C - 1").collect() + assert(result(0).getAs[BigDecimal](0) === + new BigDecimal("123456789012345.543215432154321000")) } test("SQL query as table name") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 1d04513a44672..3ac312d6f4c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -63,18 +63,18 @@ class JsonSuite extends QueryTest with TestJsonData { checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) checkTypePromotion( - Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.SYSTEM_DEFAULT)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) checkTypePromotion( - Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.SYSTEM_DEFAULT)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) checkTypePromotion( - Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) + Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.SYSTEM_DEFAULT)) checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber)), enforceCorrectType(intNumber, TimestampType)) @@ -115,7 +115,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(NullType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -126,7 +126,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType.Unlimited, StringType) + checkDataType(BooleanType, DecimalType.SYSTEM_DEFAULT, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) checkDataType(BooleanType, StructType(Nil), StringType) @@ -135,7 +135,7 @@ class JsonSuite extends 
QueryTest with TestJsonData { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(IntegerType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -143,23 +143,24 @@ class JsonSuite extends QueryTest with TestJsonData { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(LongType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DoubleType, DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) - // DoubleType - checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) - checkDataType(DecimalType.Unlimited, StringType, StringType) - checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) - checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) + // DecimalType + checkDataType(DecimalType.SYSTEM_DEFAULT, DecimalType.SYSTEM_DEFAULT, + DecimalType.SYSTEM_DEFAULT) + checkDataType(DecimalType.SYSTEM_DEFAULT, StringType, StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.SYSTEM_DEFAULT, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -213,7 +214,7 @@ class JsonSuite extends QueryTest with TestJsonData { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, StringType) } @@ -240,7 +241,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -270,7 +271,7 @@ class JsonSuite extends QueryTest with TestJsonData { val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, true), true), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, true), true), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, true), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.SYSTEM_DEFAULT, true), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, true), true) :: StructField("arrayOfInteger", ArrayType(LongType, true), true) :: @@ -284,7 +285,7 @@ class JsonSuite extends QueryTest with TestJsonData { StructField("field3", StringType, true) :: Nil), true), true) :: StructField("struct", 
StructType( StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: + StructField("field2", DecimalType.SYSTEM_DEFAULT, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(LongType, true), true) :: StructField("field2", ArrayType(StringType, true), true) :: Nil), true) :: Nil) @@ -385,7 +386,7 @@ class JsonSuite extends QueryTest with TestJsonData { val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType.Unlimited, true) :: + StructField("num_num_2", DecimalType.SYSTEM_DEFAULT, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -421,11 +422,11 @@ class JsonSuite extends QueryTest with TestJsonData { Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) - // Widening to DecimalType + // Widening to DoubleType checkAnswer( - sql("select num_num_2 + 1.2 from jsonTable where num_num_2 > 1.1"), - Row(new java.math.BigDecimal("21474836472.1")) :: - Row(new java.math.BigDecimal("92233720368547758071.2")) :: Nil + sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), + Row(21474836472.2) :: + Row(92233720368547758071.3) :: Nil ) // Widening to DoubleType @@ -442,8 +443,8 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( - sql("select num_str + 1.2 from jsonTable where num_str > 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758061.2").doubleValue) + sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. @@ -489,9 +490,9 @@ class JsonSuite extends QueryTest with TestJsonData { // in the Project. checkAnswer( jsonDF. - where('num_str > BigDecimal("92233720368547758060")). + where('num_str >= BigDecimal("92233720368547758060")). select(('num_str + 1.2).as("num")), - Row(new java.math.BigDecimal("92233720368547758061.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue()) ) // The following test will fail. The type of num_str is StringType. 
@@ -610,7 +611,7 @@ class JsonSuite extends QueryTest with TestJsonData { val jsonDF = ctx.read.json(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", LongType, true) :: @@ -668,7 +669,7 @@ class JsonSuite extends QueryTest with TestJsonData { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType.Unlimited, true) :: + StructField("bigInteger", DecimalType.SYSTEM_DEFAULT, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 7b16eba00d6fb..3a5b860484e86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -122,14 +122,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { sqlContext.read.parquet(dir.getCanonicalPath).collect() } } - - // Unlimited-length decimals are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType.Unlimited).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } } test("date type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 4f98776b91160..7f16b1125c7a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -509,7 +509,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { FloatType, DoubleType, DecimalType(10, 5), - DecimalType.Unlimited, + DecimalType.SYSTEM_DEFAULT, DateType, TimestampType, StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 54e1efb6e36e7..da53ec16b5c41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -44,7 +44,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo StructField("doubleType", DoubleType, nullable = false), StructField("bigintType", LongType, nullable = false), StructField("tinyintType", ByteType, nullable = false), - StructField("decimalType", DecimalType.Unlimited, nullable = false), + StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false), StructField("fixedDecimalType", DecimalType(5, 1), nullable = false), StructField("binaryType", BinaryType, nullable = false), StructField("booleanType", BooleanType, nullable = false), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 2c916f3322b6d..143aadc08b1c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -202,7 +202,7 @@ class TableScanSuite extends DataSourceTest { StructField("longField_:,<>=+/~^", LongType, true) :: StructField("floatField", FloatType, true) :: StructField("doubleField", DoubleType, true) :: - StructField("decimalField1", DecimalType.Unlimited, true) :: + StructField("decimalField1", DecimalType.USER_DEFAULT, true) :: StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index a8f2ee37cb8ed..592cfa0ee8380 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -179,7 +179,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -195,8 +195,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.SYSTEM_DEFAULT + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.SYSTEM_DEFAULT case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -813,9 +813,6 @@ private[hive] trait HiveInspectors { private def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) - case _ => new DecimalTypeInfo( - HiveShim.UNLIMITED_DECIMAL_PRECISION, - HiveShim.UNLIMITED_DECIMAL_SCALE) } def toTypeInfo: TypeInfo = dt match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 8518e333e8058..620b8a44d8a9b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -377,7 +377,7 @@ private[hive] object HiveQl extends Logging { DecimalType(precision.getText.toInt, scale.getText.toInt) case Token("TOK_DECIMAL", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited + case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -1369,7 +1369,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case 
Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.Unlimited) + Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => From 52de3acca4ce8c36fd4c9ce162473a091701bbc7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 23 Jul 2015 18:53:07 -0700 Subject: [PATCH 035/219] [SPARK-9122] [MLLIB] [PySpark] spark.mllib regression support batch predict spark.mllib support batch predict for LinearRegressionModel, RidgeRegressionModel and LassoModel. Author: Yanbo Liang Closes #7614 from yanboliang/spark-9122 and squashes the following commits: 4e610c0 [Yanbo Liang] spark.mllib regression support batch predict --- python/pyspark/mllib/regression.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 8e90adee5f4c2..5b7afc15ddfba 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -97,9 +97,11 @@ class LinearRegressionModelBase(LinearModel): def predict(self, x): """ - Predict the value of the dependent variable given a vector x - containing values for the independent variables. + Predict the value of the dependent variable given a vector or + an RDD of vectors containing values for the independent variables. """ + if isinstance(x, RDD): + return x.map(self.predict) x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -124,6 +126,8 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) @@ -267,6 +271,8 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) @@ -382,6 +388,8 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) From d249636e59fabd8ca57a47dc2cbad9c4a4e7a750 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 23 Jul 2015 20:06:54 -0700 Subject: [PATCH 036/219] [SPARK-9216] [STREAMING] Define KinesisBackedBlockRDDs For more information see master JIRA: https://issues.apache.org/jira/browse/SPARK-9215 Design Doc: https://docs.google.com/document/d/1k0dl270EnK7uExrsCE7jYw7PYx0YC935uBcxn3p0f58/edit Author: Tathagata Das Closes #7578 from tdas/kinesis-rdd and squashes the following commits: 543d208 [Tathagata Das] Fixed scala style 5082a30 [Tathagata Das] Fixed scala style 3f40c2d [Tathagata Das] Addressed comments c4f25d2 [Tathagata Das] Addressed comment d3d64d1 [Tathagata Das] Minor update f6e35c8 [Tathagata Das] Added retry logic to make it more robust 8874b70 [Tathagata Das] Updated Kinesis RDD 575bdbc [Tathagata Das] Fix scala style issues 4a36096 [Tathagata Das] Add license 5da3995 [Tathagata Das] 
Changed KinesisSuiteHelper to KinesisFunSuite 528e206 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into kinesis-rdd 3ae0814 [Tathagata Das] Added KinesisBackedBlockRDD --- .../kinesis/KinesisBackedBlockRDD.scala | 285 ++++++++++++++++++ .../streaming/kinesis/KinesisTestUtils.scala | 2 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 246 +++++++++++++++ .../streaming/kinesis/KinesisFunSuite.scala | 13 +- .../kinesis/KinesisStreamSuite.scala | 4 +- 5 files changed, 545 insertions(+), 5 deletions(-) create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala new file mode 100644 index 0000000000000..8f144a4d974a8 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConversions._ +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.NextIterator + + +/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. 
+ */
+private[kinesis]
+case class SequenceNumberRange(
+    streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String)
+
+/** Class representing an array of Kinesis sequence number ranges */
+private[kinesis]
+case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) {
+  def isEmpty(): Boolean = ranges.isEmpty
+  def nonEmpty(): Boolean = ranges.nonEmpty
+  override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")")
+}
+
+private[kinesis]
+object SequenceNumberRanges {
+  def apply(range: SequenceNumberRange): SequenceNumberRanges = {
+    new SequenceNumberRanges(Array(range))
+  }
+}
+
+
+/** Partition storing the information of the ranges of Kinesis sequence numbers to read */
+private[kinesis]
+class KinesisBackedBlockRDDPartition(
+    idx: Int,
+    blockId: BlockId,
+    val isBlockIdValid: Boolean,
+    val seqNumberRanges: SequenceNumberRanges
+  ) extends BlockRDDPartition(blockId, idx)
+
+/**
+ * A BlockRDD where the block data is backed by Kinesis, which can be accessed using the
+ * sequence numbers of the corresponding blocks.
+ */
+private[kinesis]
+class KinesisBackedBlockRDD(
+    sc: SparkContext,
+    regionId: String,
+    endpointUrl: String,
+    @transient blockIds: Array[BlockId],
+    @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges],
+    @transient isBlockIdValid: Array[Boolean] = Array.empty,
+    retryTimeoutMs: Int = 10000,
+    awsCredentialsOption: Option[SerializableAWSCredentials] = None
+  ) extends BlockRDD[Array[Byte]](sc, blockIds) {
+
+  require(blockIds.length == arrayOfseqNumberRanges.length,
+    "Number of blockIds is not equal to the number of sequence number ranges")
+
+  override def isValid(): Boolean = true
+
+  override def getPartitions: Array[Partition] = {
+    Array.tabulate(blockIds.length) { i =>
+      val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i)
+      new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i))
+    }
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+    val blockManager = SparkEnv.get.blockManager
+    val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
+    val blockId = partition.blockId
+
+    def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
+      logDebug(s"Read partition data of $this from block manager, block $blockId")
+      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
+    }
+
+    def getBlockFromKinesis(): Iterator[Array[Byte]] = {
+      val credentials = awsCredentialsOption.getOrElse {
+        new DefaultAWSCredentialsProviderChain().getCredentials()
+      }
+      partition.seqNumberRanges.ranges.iterator.flatMap { range =>
+        new KinesisSequenceRangeIterator(
+          credentials, endpointUrl, regionId, range, retryTimeoutMs)
+      }
+    }
+    if (partition.isBlockIdValid) {
+      getBlockFromBlockManager().getOrElse { getBlockFromKinesis() }
+    } else {
+      getBlockFromKinesis()
+    }
+  }
+}
+
+
+/**
+ * An iterator that returns the Kinesis data based on the given range of sequence numbers.
+ * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber,
+ * until the endSequenceNumber is reached.
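+ *
+ * In pseudo-code, the control flow sketched here (a sketch only, not the literal code below) is:
+ * {{{
+ *   records = getRecords(AT_SEQUENCE_NUMBER, range.fromSeqNumber)
+ *   while (lastSeqNumber != range.toSeqNumber)
+ *     if (records exhausted) records = getRecords(AFTER_SEQUENCE_NUMBER, lastSeqNumber)
+ *     emit the next record and remember its sequence number as lastSeqNumber
+ * }}}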
+ */ +private[kinesis] +class KinesisSequenceRangeIterator( + credentials: AWSCredentials, + endpointUrl: String, + regionId: String, + range: SequenceNumberRange, + retryTimeoutMs: Int + ) extends NextIterator[Array[Byte]] with Logging { + + private val client = new AmazonKinesisClient(credentials) + private val streamName = range.streamName + private val shardId = range.shardId + + private var toSeqNumberReceived = false + private var lastSeqNumber: String = null + private var internalIterator: Iterator[Record] = null + + client.setEndpoint(endpointUrl, "kinesis", regionId) + + override protected def getNext(): Array[Byte] = { + var nextBytes: Array[Byte] = null + if (toSeqNumberReceived) { + finished = true + } else { + + if (internalIterator == null) { + + // If the internal iterator has not been initialized, + // then fetch records from starting sequence number + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + } else if (!internalIterator.hasNext) { + + // If the internal iterator does not have any more records, + // then fetch more records after the last consumed sequence number + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } + + if (!internalIterator.hasNext) { + + // If the internal iterator still does not have any data, then throw exception + // and terminate this iterator + finished = true + throw new SparkException( + s"Could not read until the end sequence number of the range: $range") + } else { + + // Get the record, copy the data into a byte array and remember its sequence number + val nextRecord: Record = internalIterator.next() + val byteBuffer = nextRecord.getData() + nextBytes = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(nextBytes) + lastSeqNumber = nextRecord.getSequenceNumber() + + // If the this record's sequence number matches the stopping sequence number, then make sure + // the iterator is marked finished next time getNext() is called + if (nextRecord.getSequenceNumber == range.toSeqNumber) { + toSeqNumberReceived = true + } + } + + } + nextBytes + } + + override protected def close(): Unit = { + client.shutdown() + } + + /** + * Get records starting from or after the given sequence number. + */ + private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + val shardIterator = getKinesisIterator(iteratorType, seqNum) + val result = getRecordsAndNextKinesisIterator(shardIterator) + result._1 + } + + /** + * Get the records starting from using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for next consumption. + */ + private def getRecordsAndNextKinesisIterator( + shardIterator: String): (Iterator[Record], String) = { + val getRecordsRequest = new GetRecordsRequest + getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setShardIterator(shardIterator) + val getRecordsResult = retryOrTimeout[GetRecordsResult]( + s"getting records using shard iterator") { + client.getRecords(getRecordsRequest) + } + (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number. 
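+   * For example, ShardIteratorType.AT_SEQUENCE_NUMBER positions the iterator at the given
+   * record itself, while AFTER_SEQUENCE_NUMBER positions it just past that record, which is
+   * how this class resumes after the last consumed sequence number.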
+ */ + private def getKinesisIterator( + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest + getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setStreamName(streamName) + getShardIteratorRequest.setShardId(shardId) + getShardIteratorRequest.setShardIteratorType(iteratorType.toString) + getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) + val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + s"getting shard iterator from sequence number $sequenceNumber") { + client.getShardIterator(getShardIteratorRequest) + } + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String)(body: => T): T = { + import KinesisSequenceRangeIterator._ + + var startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} + +private[streaming] +object KinesisSequenceRangeIterator { + val MAX_RETRIES = 3 + val MIN_RETRY_WAIT_TIME_MS = 100 +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index f6bf552e6bb8e..0ff1b7ed0fd90 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -177,7 +177,7 @@ private class KinesisTestUtils( private[kinesis] object KinesisTestUtils { - val envVarName = "RUN_KINESIS_TESTS" + val envVarName = "ENABLE_KINESIS_TESTS" val shouldRunTests = sys.env.get(envVarName) == Some("1") diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala new file mode 100644 index 0000000000000..b2e2a4246dbd5 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} + +class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { + + private val regionId = "us-east-1" + private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com" + private val testData = 1 to 8 + + private var testUtils: KinesisTestUtils = null + private var shardIds: Seq[String] = null + private var shardIdToData: Map[String, Seq[Int]] = null + private var shardIdToSeqNumbers: Map[String, Seq[String]] = null + private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null + private var shardIdToRange: Map[String, SequenceNumberRange] = null + private var allRanges: Seq[SequenceNumberRange] = null + + private var sc: SparkContext = null + private var blockManager: BlockManager = null + + + override def beforeAll(): Unit = { + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils(endpointUrl) + testUtils.createStream() + + shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") + + shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq + shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }} + shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} + shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => + val seqNumRange = SequenceNumberRange( + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + (shardId, seqNumRange) + } + allRanges = shardIdToRange.values.toSeq + + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + + testIfEnabled("Basic reading from Kinesis") { + // Verify all data using multiple ranges in a single RDD partition + val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(1), + Array(SequenceNumberRanges(allRanges.toArray)) + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData1.toSet === testData.toSet) + + // Verify all data using one range in each of the multiple RDD partitions + val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData2.toSet === testData.toSet) + + // Verify ordering within each partition + val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => 
SequenceNumberRanges(Array(range)) }.toArray
+    ).map { bytes => new String(bytes).toInt }.collectPartitions()
+    assert(receivedData3.length === allRanges.size)
+    for (i <- 0 until allRanges.size) {
+      assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId))
+    }
+  }
+
+  testIfEnabled("Read data available in both block manager and Kinesis") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2)
+  }
+
+  testIfEnabled("Read data available only in block manager, not in Kinesis") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0)
+  }
+
+  testIfEnabled("Read data available only in Kinesis, not in block manager") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2)
+  }
+
+  testIfEnabled("Read data available partially in block manager, rest in Kinesis") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1)
+  }
+
+  testIfEnabled("Test isBlockValid skips block fetching from block manager") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0,
+      testIsBlockValid = true)
+  }
+
+  testIfEnabled("Test whether RDD is valid after removing blocks from block manager") {
+    testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2,
+      testBlockRemove = true)
+  }
+
+  /**
+   * Test the KinesisBackedBlockRDD, by writing some partitions of the data to the block
+   * manager and the rest to Kinesis, and then reading it all back using the RDD.
+   * It can also test whether the partitions that were read from Kinesis were again stored
+   * in the block manager.
+   *
+   * @param numPartitions Number of partitions in RDD
+   * @param numPartitionsInBM Number of partitions to write to the BlockManager.
+   *                          Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager
+   * @param numPartitionsInKinesis Number of partitions to write to Kinesis.
+   *                               Partitions (numPartitions - 1 - numPartitionsInKinesis) to
+   *                               (numPartitions - 1) will be written to Kinesis
+   * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching
+   * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with
+   *                        reads falling back to Kinesis
+   * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInKinesis = 4
+   *
+   *   numPartitionsInBM = 3
+   *   |------------------|
+   *   |                  |
+   *   0       1       2       3       4
+   *           |                       |
+   *           |-----------------------|
+   *             numPartitionsInKinesis = 4
+   */
+  private def testRDD(
+      numPartitions: Int,
+      numPartitionsInBM: Int,
+      numPartitionsInKinesis: Int,
+      testIsBlockValid: Boolean = false,
+      testBlockRemove: Boolean = false
+    ): Unit = {
+    require(shardIds.size > 1, "Need at least 2 shards to test")
+    require(numPartitionsInBM <= shardIds.size,
+      "Number of partitions in BlockManager cannot be more than the Kinesis test shards available")
+    require(numPartitionsInKinesis <= shardIds.size,
+      "Number of partitions in Kinesis cannot be more than the Kinesis test shards available")
+    require(numPartitionsInBM <= numPartitions,
+      "Number of partitions in BlockManager cannot be more than that in RDD")
+    require(numPartitionsInKinesis <= numPartitions,
+      "Number of partitions in Kinesis cannot be more than that in RDD")
+
+    // Put necessary blocks in the block manager
+    val blockIds = fakeBlockIds(numPartitions)
+    blockIds.foreach(blockManager.removeBlock(_))
+    (0 until numPartitionsInBM).foreach { i =>
+      val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() }
+      blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY)
+    }
+
+    // Create the necessary ranges to use in the RDD
+    val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)(
+      SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")))
+    val realRanges = Array.tabulate(numPartitionsInKinesis) { i =>
+      val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis)))
+      SequenceNumberRanges(Array(range))
+    }
+    val ranges = fakeRanges ++ realRanges
+
+    // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not
+    require(
+      blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty),
+      "Expected blocks not in BlockManager"
+    )
+
+    require(
+      blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty),
+      "Unexpected blocks in BlockManager"
+    )
+
+    // Make sure that the last `numPartitionsInKinesis` ranges point to the real test stream,
+    // and the others do not
+    require(
+      ranges.takeRight(numPartitionsInKinesis).forall {
+        _.ranges.forall { _.streamName == testUtils.streamName }
+      }, "Incorrect configuration of RDD, expected ranges not set"
+    )
+
+    require(
+      ranges.dropRight(numPartitionsInKinesis).forall {
+        _.ranges.forall { _.streamName != testUtils.streamName }
+      }, "Incorrect configuration of RDD, unexpected ranges set"
+    )
+
+    val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges)
+    val collectedData = rdd.map { bytes =>
+      new String(bytes).toInt
+    }.collect()
+    assert(collectedData.toSet === testData.toSet)
+
+    // Verify that the block fetching is skipped when isBlockValid is set to false.
+    // This is done by using an RDD whose data is only in memory but is set to skip block fetching.
+    // Using that RDD will throw an exception, as it skips block fetching even if the blocks are
+    // in the BlockManager.
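+    // (Concretely: with isBlockIdValid = false for every partition, compute() never calls
+    // getBlockFromBlockManager(), and since numPartitionsInKinesis is required to be 0 here,
+    // every range is a fake one, so the fallback read from Kinesis throws.)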
+    if (testIsBlockValid) {
+      require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
+      require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
+      val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray,
+        ranges, isBlockIdValid = Array.fill(blockIds.length)(false))
+      intercept[SparkException] {
+        rdd2.collect()
+      }
+    }
+
+    // Verify that the RDD is not invalid after the blocks are removed, and can still read data
+    // from Kinesis
+    if (testBlockRemove) {
+      require(numPartitions === numPartitionsInKinesis,
+        "All partitions must be in Kinesis for this test")
+      require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test")
+      rdd.removeBlocks()
+      assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet)
+    }
+  }
+
+  /** Generate fake block ids */
+  private def fakeBlockIds(num: Int): Array[BlockId] = {
+    Array.tabulate(num) { i => new StreamBlockId(0, i) }
+  }
+}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
index 6d011f295e7f7..8373138785a89 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -23,15 +23,24 @@ import org.apache.spark.SparkFunSuite
 * Helper class that runs Kinesis real data transfer tests, or
 * ignores them based on whether the corresponding env variable is set.
 */
-trait KinesisSuiteHelper { self: SparkFunSuite =>
+trait KinesisFunSuite extends SparkFunSuite {
  import KinesisTestUtils._

  /** Run the test if environment variable is set or ignore the test */
- def testOrIgnore(testName: String)(testBody: => Unit) {
+ def testIfEnabled(testName: String)(testBody: => Unit) {
    if (shouldRunTests) {
      test(testName)(testBody)
    } else {
      ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody)
    }
  }
+
+  /** Run the given body of code only if Kinesis tests are enabled */
+  def runIfTestsEnabled(message: String)(body: => Unit): Unit = {
+    if (shouldRunTests) {
+      body
+    } else {
+      ignore(s"$message [enable by setting env var $envVarName=1]")()
+    }
+  }
 }
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 50f71413abf37..f9c952b9468bb 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}

-class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper
+class KinesisStreamSuite extends KinesisFunSuite
  with Eventually with BeforeAndAfter with BeforeAndAfterAll {

  // This is the name that KCL uses to save metadata to DynamoDB
@@ -83,7 +83,7 @@ class KinesisStreamSuite
 * you must have AWS credentials available through the default AWS provider chain,
 * and you have to set the system environment variable ENABLE_KINESIS_TESTS=1.
*/ - testOrIgnore("basic operation") { + testIfEnabled("basic operation") { val kinesisTestUtils = new KinesisTestUtils() try { kinesisTestUtils.createStream() From d4d762f275749a923356cd84de549b14c22cc3eb Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 23 Jul 2015 22:35:41 -0700 Subject: [PATCH 037/219] [SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable. The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest. Author: Ram Sriharsha Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits: 6591dc6 [Ram Sriharsha] add documentation for params b7024b1 [Ram Sriharsha] cleanup f0e2bfb [Ram Sriharsha] merge with master 108d3d7 [Ram Sriharsha] merge with master 4f74126 [Ram Sriharsha] Allow label/ features columns to be configurable --- .../spark/ml/classification/OneVsRest.scala | 17 ++++++++++++- .../ml/classification/OneVsRestSuite.scala | 24 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index ea757c5e40c76..1741f19dc911c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams { /** * param for the base binary classifier that we reduce multiclass classification into. + * The base classifier input and output columns are ignored in favor of + * the ones specified in [[OneVsRest]]. * @group param */ val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier") @@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String) set(classifier, value.asInstanceOf[ClassifierType]) } + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } @@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String) val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier - classifier.fit(trainingDataset, classifier.labelCol -> labelColName) + val paramMap = new ParamMap() + paramMap.put(classifier.labelCol -> labelColName) + paramMap.put(classifier.featuresCol -> getFeaturesCol) + paramMap.put(classifier.predictionCol -> getPredictionCol) + classifier.fit(trainingDataset, paramMap) }.toArray[ClassificationModel[_, _]] if (handlePersistence) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 75cf5bd4ead4f..3775292f6dca7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute +import 
org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS @@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { ova.fit(datasetWithLabelMetadata) } + test("SPARK-8092: ensure label features and prediction cols are configurable") { + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexed") + + val indexedDataset = labelIndexer + .fit(dataset) + .transform(dataset) + .drop("label") + .withColumnRenamed("features", "f") + + val ova = new OneVsRest() + ova.setClassifier(new LogisticRegression()) + .setLabelCol(labelIndexer.getOutputCol) + .setFeaturesCol("f") + .setPredictionCol("p") + + val ovaModel = ova.fit(indexedDataset) + val transformedDataset = ovaModel.transform(indexedDataset) + val outputFields = transformedDataset.schema.fieldNames.toSet + assert(outputFields.contains("p")) + } + test("SPARK-8049: OneVsRest shouldn't output temp columns") { val logReg = new LogisticRegression() .setMaxIter(1) From 408e64b284ef8bd6796d815b5eb603312d090b74 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 23 Jul 2015 23:40:01 -0700 Subject: [PATCH 038/219] [SPARK-9294][SQL] cleanup comments, code style, naming typo for the new aggregation fix some comments and code style for https://github.com/apache/spark/pull/7458 Author: Wenchen Fan Closes #7619 from cloud-fan/agg-clean and squashes the following commits: 3925457 [Wenchen Fan] one more... cc78357 [Wenchen Fan] one more cleanup 26f6a93 [Wenchen Fan] some minor cleanup for the new aggregation --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../expressions/aggregate/interfaces.scala | 18 ++-- .../apache/spark/sql/execution/Exchange.scala | 6 +- .../spark/sql/execution/SparkStrategies.scala | 8 +- .../aggregate/sortBasedIterators.scala | 82 ++++++------------- .../spark/sql/execution/aggregate/utils.scala | 10 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 9 +- 7 files changed, 46 insertions(+), 89 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8cadbc57e87e1..e916887187dc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -533,7 +533,7 @@ class Analyzer( case min: Min if isDistinct => min // For other aggregate functions, DISTINCT keyword is not supported for now. // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other if isDistinct => + case other: AggregateExpression1 if isDistinct => failAnalysis(s"$name does not support DISTINCT keyword.") // If it does not have DISTINCT keyword, we will return it as is. 
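      // (With this change, DISTINCT calls on functions outside AggregateExpression1, i.e. the
      // new AggregateExpression2 code path, also fall through to the case below unchanged.)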
case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index d3fee1ade05e6..10bd19c8a840f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCod import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction1]]. */ +/** The mode of an [[AggregateFunction2]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -42,8 +42,8 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers - * containing intermediate results for this function and the generate final result. + * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. */ @@ -85,12 +85,12 @@ private[sql] case class AggregateExpression2( override def nullable: Boolean = aggregateFunction.nullable override def references: AttributeSet = { - val childReferemces = mode match { + val childReferences = mode match { case Partial | Complete => aggregateFunction.references.toSeq case PartialMerge | Final => aggregateFunction.bufferAttributes } - AttributeSet(childReferemces) + AttributeSet(childReferences) } override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)" @@ -99,10 +99,8 @@ private[sql] case class AggregateExpression2( abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { - self: Product => - /** An aggregate function is not foldable. */ - override def foldable: Boolean = false + final override def foldable: Boolean = false /** * The offset of this function's buffer in the underlying buffer shared with other functions. 
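As an aside, the Partial / PartialMerge / Final contract documented in the hunk above can be illustrated with a tiny self-contained average computation, independent of Spark. Everything below (the object and method names, the driver in main) is hypothetical; it sketches only the mode semantics, not Spark's actual AggregateFunction2 machinery:

// Sketch: two-phase aggregation of an average, outside Spark.
object TwoPhaseAvgSketch {
  // The per-group "aggregation buffer", analogous to bufferAttributes.
  case class Buffer(sum: Double, count: Long)

  // Partial mode: fold original input rows into a buffer.
  def update(b: Buffer, row: Double): Buffer = Buffer(b.sum + row, b.count + 1)

  // PartialMerge / Final mode: merge buffers produced by different tasks.
  def merge(b1: Buffer, b2: Buffer): Buffer = Buffer(b1.sum + b2.sum, b1.count + b2.count)

  // Final mode only: produce the result from the fully merged buffer.
  def evaluate(b: Buffer): Double = if (b.count == 0) Double.NaN else b.sum / b.count

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1.0, 2.0), Seq(3.0, 4.0, 5.0))
    val partialBuffers = partitions.map(_.foldLeft(Buffer(0.0, 0L))(update))
    println(evaluate(partialBuffers.reduce(merge)))  // prints 3.0
  }
}

The point of the split is that update sees raw rows once per task, while merge and evaluate only ever see fixed-size buffers, which is what lets the final step run on shuffled partial results.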
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index d31e265a293e9..41a0c519ba527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -224,13 +224,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // compatible. // TODO: ASSUMES TRANSITIVITY? def compatible: Boolean = - !operator.children + operator.children .map(_.outputPartitioning) .sliding(2) - .map { + .forall { case Seq(a) => true case Seq(a, b) => a.compatibleWith(b) - }.exists(!_) + } // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f54aa2027f6a6..eb4be1900b153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -190,12 +190,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sqlContext.conf.codegenEnabled).isDefined } - def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists { - case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false + def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { + case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true // The generated set implementation is pretty limited ATM. case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => false - case _ => true + Seq(IntegerType, LongType).contains(exprs.head.dataType) => true + case _ => false } def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index ce1cbdc9cb090..b8e95a5a2a4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -67,13 +67,6 @@ private[sql] abstract class SortAggregationIterator( functions } - // All non-algebraic aggregate functions. - protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { - aggregateFunctions.collect { - case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } - // Positions of those non-algebraic aggregate functions in aggregateFunctions. // For example, we have func1, func2, func3, func4 in aggregateFunctions, and // func2 and func3 are non-algebraic aggregate functions. @@ -91,6 +84,10 @@ private[sql] abstract class SortAggregationIterator( positions.toArray } + // All non-algebraic aggregate functions. + protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = + nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions) + // This is used to project expressions for the grouping expressions. 
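  // e.g. for GROUP BY key1, key2 this maps each input row to its (key1, key2) grouping key.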
protected val groupGenerator = newMutableProjection(groupingExpressions, inputAttributes)() @@ -179,8 +176,6 @@ private[sql] abstract class SortAggregationIterator( // For the below compare method, we do not need to make a copy of groupingKey. val groupingKey = groupGenerator(currentRow) // Check if the current row belongs the current input row. - currentGroupingKey.equals(groupingKey) - if (currentGroupingKey == groupingKey) { processRow(currentRow) } else { @@ -288,10 +283,7 @@ class PartialSortAggregationIterator( // This projection is used to update buffer values for all AlgebraicAggregates. private val algebraicUpdateProjection = { - val bufferSchema = aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } + val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -348,19 +340,14 @@ class PartialMergeSortAggregationIterator( inputAttributes, inputIter) { - private val placeholderAttribtues = + private val placeholderAttributes = Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { val bufferSchemata = - placeholderAttribtues ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ placeholderAttribtues ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -444,13 +431,8 @@ class FinalSortAggregationIterator( // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) @@ -462,13 +444,8 @@ class FinalSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. 
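  // Non-algebraic functions occupy NoOp slots in this projection; they are instead evaluated
  // one by one through nonAlgebraicAggregateFunctions.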
private val algebraicEvalProjection = { val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -599,11 +576,10 @@ class FinalAndCompleteSortAggregationIterator( } // All non-algebraic aggregate functions with mode Final. - private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = finalAggregateFunctions.collect { case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } + } // All aggregate functions with mode Complete. private val completeAggregateFunctions: Array[AggregateFunction2] = { @@ -617,11 +593,10 @@ class FinalAndCompleteSortAggregationIterator( } // All non-algebraic aggregate functions with mode Complete. - private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = completeAggregateFunctions.collect { case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func - }.toArray - } + } // This projection is used to merge buffer values for all AlgebraicAggregates with mode // Final. @@ -633,13 +608,9 @@ class FinalAndCompleteSortAggregationIterator( val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) val bufferSchemata = - offsetAttributes ++ finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } ++ completeOffsetAttributes + offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++ + completeOffsetAttributes ++ offsetAttributes ++ + finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes val mergeExpressions = placeholderExpressions ++ finalAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions @@ -658,10 +629,8 @@ class FinalAndCompleteSortAggregationIterator( val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) val bufferSchema = - offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } + offsetAttributes ++ finalOffsetAttributes ++ + completeAggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions @@ -673,13 +642,8 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. 
private val algebraicEvalProjection = { val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.bufferAttributes - case agg: AggregateFunction2 => agg.bufferAttributes - } ++ offsetAttributes ++ aggregateFunctions.flatMap { - case ae: AlgebraicAggregate => ae.cloneBufferAttributes - case agg: AggregateFunction2 => agg.cloneBufferAttributes - } + offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ + offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1cb27710e0480..5bbe6c162ff4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -191,10 +191,7 @@ object Utils { } val groupExpressionMap = namedGroupingExpressions.toMap val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - val partialAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Partial, isDistinct) - } + val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => agg.aggregateFunction.bufferAttributes } @@ -208,10 +205,7 @@ object Utils { child) // 2. Create an Aggregate Operator for final aggregations. - val finalAggregateExpressions = aggregateExpressions.map { - case AggregateExpression2(aggregateFunction, mode, isDistinct) => - AggregateExpression2(aggregateFunction, Final, isDistinct) - } + val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) val finalAggregateAttributes = finalAggregateExpressions.map { expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ab8dce603c117..95a1106cf072d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1518,18 +1518,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-8945: add and subtract expressions for interval type") { import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123))) + checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) checkAnswer(df.select(df("i") + new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123))) + Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) checkAnswer(df.select(df("i") - new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123))) + Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) // unary minus checkAnswer(df.select(-df("i")), - Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))) + Row(new 
Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } } From cb8c241f05b9ab4ad0cd07df14d454cc5a4554cc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Jul 2015 01:18:43 -0700 Subject: [PATCH 039/219] [SPARK-9200][SQL] Don't implicitly cast non-atomic types to string type. Author: Reynold Xin Closes #7636 from rxin/complex-string-implicit-cast and squashes the following commits: 3e67327 [Reynold Xin] [SPARK-9200][SQL] Don't implicitly cast non-atomic types to string type. --- .../spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 3 ++- .../sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d56ceeadc9e85..87ffbfe791b93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -720,7 +720,8 @@ object HiveTypeCoercion { case (StringType, DateType) => Cast(e, DateType) case (StringType, TimestampType) => Cast(e, TimestampType) case (StringType, BinaryType) => Cast(e, BinaryType) - case (any, StringType) if any != StringType => Cast(e, StringType) + // Cast any atomic type to string. + case (any: AtomicType, StringType) if any != StringType => Cast(e, StringType) // When we reach here, input type is not acceptable for any types in this type collection, // try to find the first one we can implicitly cast. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d0fb95b580ad2..55865bdb534b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -115,6 +115,14 @@ class HiveTypeCoercionSuite extends PlanTest { shouldNotCast(IntegerType, ArrayType) shouldNotCast(IntegerType, MapType) shouldNotCast(IntegerType, StructType) + + shouldNotCast(IntervalType, StringType) + + // Don't implicitly cast complex types to string. + shouldNotCast(ArrayType(StringType), StringType) + shouldNotCast(MapType(StringType, StringType), StringType) + shouldNotCast(new StructType().add("a1", StringType), StringType) + shouldNotCast(MapType(StringType, StringType), StringType) } test("tightest common bound for types") { From 8fe32b4f7d49607ad5f2479d454b33ab3f079f7c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Jul 2015 01:47:13 -0700 Subject: [PATCH 040/219] [build] Enable memory leak detection for Tungsten. This was turned off accidentally in #7591. Author: Reynold Xin Closes #7637 from rxin/enable-mem-leak-detect and squashes the following commits: 34bc3ef [Reynold Xin] Enable memory leak detection for Tungsten. 
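For context, the setting restored below is just a JVM system property passed to the test JVMs; a minimal sketch of how such a flag could be read at runtime follows (the property name comes from this patch, but the reading code is illustrative and not how Spark internally consumes it):

// Illustrative only: mirrors the sys.env.get(...) == Some("1") idiom used elsewhere in
// this patch series; not Spark's actual consumer of the property.
val exceptionOnMemoryLeak = sys.props.get("spark.unsafe.exceptionOnMemoryLeak") == Some("true")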
--- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b5b0adf630b9e..61a05d375d99e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -543,7 +543,7 @@ object TestSettings { javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", - //javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", + javaOptions in Test += "-Dspark.unsafe.exceptionOnMemoryLeak=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test += "-Dderby.system.durability=test", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") From 6a7e537f3a4fd5e99a905f9842dc0ad4c348e4fd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Jul 2015 17:39:57 +0800 Subject: [PATCH 041/219] [SPARK-8756] [SQL] Keep cached information and avoid re-calculating footers in ParquetRelation2 JIRA: https://issues.apache.org/jira/browse/SPARK-8756 Currently, in ParquetRelation2, footers are re-read every time refresh() is called. But we can check if it is possibly changed before we do the reading because reading all footers will be expensive when there are too many partitions. This pr fixes this by keeping some cached information to check it. Author: Liang-Chi Hsieh Closes #7154 from viirya/cached_footer_parquet_relation and squashes the following commits: 92e9347 [Liang-Chi Hsieh] Fix indentation. ae0ec64 [Liang-Chi Hsieh] Fix wrong assignment. c8fdfb7 [Liang-Chi Hsieh] Fix it. a52b6d1 [Liang-Chi Hsieh] For comments. c2a2420 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cached_footer_parquet_relation fa5458f [Liang-Chi Hsieh] Use Map to cache FileStatus and do merging previously loaded schema and newly loaded one. 6ae0911 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cached_footer_parquet_relation 21bbdec [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cached_footer_parquet_relation 12a0ed9 [Liang-Chi Hsieh] Add check of FileStatus's modification time. 186429d [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cached_footer_parquet_relation 0ef8caf [Liang-Chi Hsieh] Keep cached information and avoid re-calculating footers. --- .../apache/spark/sql/parquet/newParquet.scala | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 2f9f880c70690..c384697c0ee62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -345,24 +345,34 @@ private[sql] class ParquetRelation2( // Schema of the whole table, including partition columns. var schema: StructType = _ + // Cached leaves + var cachedLeaves: Set[FileStatus] = null + /** * Refreshes `FileStatus`es, footers, partition spec, and table schema. */ def refresh(): Unit = { - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. 
- val leaves = cachedLeafStatuses().filter { f => - isSummaryFile(f.getPath) || - !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) - }.toArray - - dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) - metadataStatuses = leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) - commonMetadataStatuses = - leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) - - // If we already get the schema, don't need to re-compute it since the schema merging is - // time-consuming. - if (dataSchema == null) { + val currentLeafStatuses = cachedLeafStatuses() + + // Check if cachedLeafStatuses is changed or not + val leafStatusesChanged = (cachedLeaves == null) || + !cachedLeaves.equals(currentLeafStatuses) + + if (leafStatusesChanged) { + cachedLeaves = currentLeafStatuses.toIterator.toSet + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. + val leaves = currentLeafStatuses.filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + }.toArray + + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + dataSchema = { val dataSchema0 = maybeDataSchema .orElse(readSchema()) From 6cd28cc21ed585ab8d1e0e7147a1a48b044c9c8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Fri, 24 Jul 2015 15:41:13 +0100 Subject: [PATCH 042/219] [SPARK-9236] [CORE] Make defaultPartitioner not reuse a parent RDD's partitioner if it has 0 partitions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See also comments on https://issues.apache.org/jira/browse/SPARK-9236 Author: François Garillot Closes #7616 from huitseeker/issue/SPARK-9236 and squashes the following commits: 217f902 [François Garillot] [SPARK-9236] Make defaultPartitioner not reuse a parent RDD's partitioner if it has 0 partitions --- .../scala/org/apache/spark/Partitioner.scala | 2 +- .../spark/rdd/PairRDDFunctionsSuite.scala | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ad68512dccb79..4b9d59975bdc2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -56,7 +56,7 @@ object Partitioner { */ def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = { val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse - for (r <- bySize if r.partitioner.isDefined) { + for (r <- bySize if r.partitioner.isDefined && r.partitioner.get.numPartitions > 0) { return r.partitioner.get } if (rdd.context.conf.contains("spark.default.parallelism")) { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index dfa102f432a02..1321ec84735b5 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -282,6 +282,29 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { )) } + // See SPARK-9326 + test("cogroup with empty RDD") { + import scala.reflect.classTag + val 
intPairCT = classTag[(Int, Int)] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[(Int, Int)](intPairCT) + + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + + // See SPARK-9326 + test("cogroup with groupByed RDD having 0 partitions") { + import scala.reflect.classTag + val intCT = classTag[Int] + + val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) + val rdd2 = sc.emptyRDD[Int](intCT).groupBy((x) => 5) + val joined = rdd1.cogroup(rdd2).collect() + assert(joined.size > 0) + } + test("rightOuterJoin") { val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1))) val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) From dfb18be0366376be3b928dbf4570448c60fe652b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 24 Jul 2015 08:24:13 -0700 Subject: [PATCH 043/219] [SPARK-9069] [SQL] follow up Address comments for #7605 cc rxin Author: Davies Liu Closes #7634 from davies/decimal_unlimited2 and squashes the following commits: b2d8b0d [Davies Liu] add doc and test for DecimalType.isWiderThan 65b251c [Davies Liu] fix test 6a91f32 [Davies Liu] fix style ca9c973 [Davies Liu] address comments --- .../catalyst/analysis/HiveTypeCoercion.scala | 30 +++++-------------- .../expressions/decimalFunctions.scala | 13 ++++++++ .../apache/spark/sql/types/DecimalType.scala | 6 +++- .../analysis/DecimalPrecisionSuite.scala | 26 ++++++++++++++++ .../analysis/HiveTypeCoercionSuite.scala | 6 ++-- 5 files changed, 55 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 87ffbfe791b93..e0527503442f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -38,7 +36,7 @@ object HiveTypeCoercion { val typeCoercionRules = PropagateTypes :: InConversion :: - WidenTypes :: + WidenSetOperationTypes :: PromoteStrings :: DecimalPrecision :: BooleanEquality :: @@ -175,7 +173,7 @@ object HiveTypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - object WidenTypes extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { private[this] def widenOutputTypes( planName: String, @@ -203,9 +201,9 @@ object HiveTypeCoercion { def castOutput(plan: LogicalPlan): LogicalPlan = { val casted = plan.output.zip(castedTypes).map { - case (hs, Some(dt)) if dt != hs.dataType => - Alias(Cast(hs, dt), hs.name)() - case (hs, _) => hs + case (e, Some(dt)) if e.dataType != dt => + Alias(Cast(e, dt), e.name)() + case (e, _) => e } Project(casted, plan) } @@ -355,20 +353,8 @@ object HiveTypeCoercion { DecimalType.bounded(range + scale, scale) } - /** - * An expression used to wrap the children when promote the precision of DecimalType to avoid - * promote multiple times. 
- */
-  case class ChangePrecision(child: Expression) extends UnaryExpression {
-    override def dataType: DataType = child.dataType
-    override def eval(input: InternalRow): Any = child.eval(input)
-    override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
-    override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
-    override def prettyName: String = "change_precision"
-  }
-
-  def changePrecision(e: Expression, dataType: DataType): Expression = {
-    ChangePrecision(Cast(e, dataType))
+  private def changePrecision(e: Expression, dataType: DataType): Expression = {
+    ChangeDecimalPrecision(Cast(e, dataType))
   }

   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -378,7 +364,7 @@ object HiveTypeCoercion {
       case e if !e.childrenResolved => e

       // Skip nodes that are already promoted
-      case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
+      case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e

       case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
         val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index b9d4736a65e26..adb33e4c8d4a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.types._

@@ -60,3 +61,15 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
     })
   }
 }
+
+/**
+ * An expression used to wrap the children when promoting the precision of DecimalType,
+ * so that the precision is not promoted multiple times.
+ */
+case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression {
+  override def dataType: DataType = child.dataType
+  override def eval(input: InternalRow): Any = child.eval(input)
+  override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
+  override def prettyName: String = "change_decimal_precision"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 26b24616d98ec..0cd352d0fa928 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -78,6 +78,10 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {

   override def toString: String = s"DecimalType($precision,$scale)"

+  /**
+   * Returns whether this DecimalType is wider than `other`. If so, `other` can be cast
+   * into `this` safely without losing any precision or range.
+ */ private[sql] def isWiderThan(other: DataType): Boolean = other match { case dt: DecimalType => (precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale @@ -109,7 +113,7 @@ object DecimalType extends AbstractDataType { @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") val Unlimited: DecimalType = SYSTEM_DEFAULT - // The decimal types compatible with other numberic types + // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) private[sql] val IntDecimal = DecimalType(10, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index f9f15e7a6608d..fc11627da6fd1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -154,4 +154,30 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { checkType(Remainder(expr, u), DoubleType) } } + + test("DecimalType.isWiderThan") { + val d0 = DecimalType(2, 0) + val d1 = DecimalType(2, 1) + val d2 = DecimalType(5, 2) + val d3 = DecimalType(15, 3) + val d4 = DecimalType(25, 4) + + assert(d0.isWiderThan(d1) === false) + assert(d1.isWiderThan(d0) === false) + assert(d1.isWiderThan(d2) === false) + assert(d2.isWiderThan(d1) === true) + assert(d2.isWiderThan(d3) === false) + assert(d3.isWiderThan(d2) === true) + assert(d4.isWiderThan(d3) === true) + + assert(d1.isWiderThan(ByteType) === false) + assert(d2.isWiderThan(ByteType) === true) + assert(d2.isWiderThan(ShortType) === false) + assert(d3.isWiderThan(ShortType) === true) + assert(d3.isWiderThan(IntegerType) === true) + assert(d3.isWiderThan(LongType) === false) + assert(d4.isWiderThan(LongType) === true) + assert(d4.isWiderThan(FloatType) === false) + assert(d4.isWiderThan(DoubleType) === false) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 55865bdb534b4..4454d51b75877 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -314,7 +314,7 @@ class HiveTypeCoercionSuite extends PlanTest { ) } - test("WidenTypes for union except and intersect") { + test("WidenSetOperationTypes for union except and intersect") { def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { logical.output.zip(expectTypes).foreach { case (attr, dt) => assert(attr.dataType === dt) @@ -332,7 +332,7 @@ class HiveTypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = HiveTypeCoercion.WidenTypes + val wt = HiveTypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) val r1 = wt(Union(left, right)).asInstanceOf[Union] @@ -353,7 +353,7 @@ class HiveTypeCoercionSuite extends PlanTest { } } - val dp = HiveTypeCoercion.WidenTypes + val dp = HiveTypeCoercion.WidenSetOperationTypes val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) From 
846cf46282da8f4b87aeee64e407a38cdc80e13b Mon Sep 17 00:00:00 2001
From: "zhichao.li"
Date: Fri, 24 Jul 2015 08:34:50 -0700
Subject: [PATCH 044/219] [SPARK-9238] [SQL] Remove two extra useless entries for bytesOfCodePointInUTF8

This is only a tentative change, and I am not certain I understand this correctly, but I believe 2 entries in `bytesOfCodePointInUTF8` are enough for the 6-byte code point case (1111110x), since only the two leading bytes 0xFC and 0xFD carry that prefix. Details can be found in the "Description" section of https://en.wikipedia.org/wiki/UTF-8.

Author: zhichao.li

Closes #7582 from zhichao-li/utf8 and squashes the following commits:

8bddd01 [zhichao.li] two extra entries
---
 .../src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 946d355f1fc28..6d8dcb1cbf876 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -48,7 +48,7 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
     3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
     4, 4, 4, 4, 4, 4, 4, 4,
     5, 5, 5, 5,
-    6, 6, 6, 6};
+    6, 6};

 public static final UTF8String EMPTY_UTF8 = UTF8String.fromString("");

From 428cde5d1c46adad344255447283dfb9716d65cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?=
Date: Fri, 24 Jul 2015 17:09:33 +0100
Subject: [PATCH 045/219] [SPARK-9250] Make change-scala-version more helpful w.r.t. valid Scala versions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Author: François Garillot

Closes #7595 from huitseeker/issue/SPARK-9250 and squashes the following commits:

80a0218 [François Garillot] [SPARK-9250] Make change-scala-version's usage more explicit, introduce a -h|--help option.
---
 dev/change-scala-version.sh | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh
index b81c00c9d6d9d..d7975dfb6475c 100755
--- a/dev/change-scala-version.sh
+++ b/dev/change-scala-version.sh
@@ -19,19 +19,23 @@

 set -e

+VALID_VERSIONS=( 2.10 2.11 )
+
 usage() {
-  echo "Usage: $(basename $0) <version>" 1>&2
+  echo "Usage: $(basename $0) [-h|--help] <version>
+where:
+  -h | --help    Display this help text
+  valid version values: ${VALID_VERSIONS[*]}
+" 1>&2
   exit 1
 }

-if [ $# -ne 1 ]; then
+if [[ ($# -ne 1) || ( $1 == "--help") || $1 == "-h" ]]; then
   usage
 fi

 TO_VERSION=$1

-VALID_VERSIONS=( 2.10 2.11 )
-
 check_scala_version() {
   for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done
   echo "Invalid Scala version: $1. 
Valid versions: ${VALID_VERSIONS[*]}" 1>&2 From 3aec9f4e2d8fcce9ddf84ab4d0e10147c18afa16 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 24 Jul 2015 09:10:11 -0700 Subject: [PATCH 046/219] [SPARK-9249] [SPARKR] local variable assigned but may not be used [[SPARK-9249] local variable assigned but may not be used - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9249) https://gist.github.com/yu-iskw/0e5b0253c11769457ea5 Author: Yu ISHIKAWA Closes #7640 from yu-iskw/SPARK-9249 and squashes the following commits: 7a51cab [Yu ISHIKAWA] [SPARK-9249][SparkR] local variable assigned but may not be used --- R/pkg/R/deserialize.R | 4 ++-- R/pkg/R/sparkR.R | 3 --- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 7d1f6b0819ed0..6d364f77be7ee 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -102,11 +102,11 @@ readList <- function(con) { readRaw <- function(con) { dataLen <- readInt(con) - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readRawLen <- function(con, dataLen) { - data <- readBin(con, raw(), as.integer(dataLen), endian = "big") + readBin(con, raw(), as.integer(dataLen), endian = "big") } readDeserialize <- function(con) { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 79b79d70943cb..76c15875b50d5 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -104,16 +104,13 @@ sparkR.init <- function( return(get(".sparkRjsc", envir = .sparkREnv)) } - sparkMem <- Sys.getenv("SPARK_MEM", "1024m") jars <- suppressWarnings(normalizePath(as.character(sparkJars))) # Classpath separator is ";" on Windows # URI needs four /// as from http://stackoverflow.com/a/18522792 if (.Platform$OS.type == "unix") { - collapseChar <- ":" uriSep <- "//" } else { - collapseChar <- ";" uriSep <- "////" } From 431ca39be51352dfcdacc87de7e64c2af313558d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Jul 2015 09:37:36 -0700 Subject: [PATCH 047/219] [SPARK-9285][SQL] Remove InternalRow's inheritance from Row. I also changed InternalRow's size/length function to numFields, to make it more obvious that it is not about bytes, but the number of fields. Author: Reynold Xin Closes #7626 from rxin/internalRow and squashes the following commits: e124daf [Reynold Xin] Fixed test case. 805ceb7 [Reynold Xin] Commented out the failed test suite. f8a9ca5 [Reynold Xin] Fixed more bugs. Still at least one more remaining. 76d9081 [Reynold Xin] Fixed data sources. 7807f70 [Reynold Xin] Fixed DataFrameSuite. cb60cd2 [Reynold Xin] Code review & small bug fixes. 0a2948b [Reynold Xin] Fixed style. 3280d03 [Reynold Xin] [SPARK-9285][SQL] Remove InternalRow's inheritance from Row. 
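A note on the shape of this change: `InternalRow` stops extending the public `Row` trait and becomes a standalone abstraction whose contract is `numFields` plus positional `get(i)` access, rather than `length`/`apply(i)`. The following is a minimal sketch of that shape, distilled from the diff below; the `Sketch*` names are invented for illustration, and the real classes live in `org.apache.spark.sql.catalyst` with many more typed accessors:

  // Internal rows expose only a field count and positional access.
  abstract class SketchInternalRow extends Serializable {
    def numFields: Int
    def get(i: Int): Any
    def isNullAt(i: Int): Boolean = get(i) == null
  }

  // Array-backed implementation, mirroring GenericInternalRow.
  class SketchGenericInternalRow(values: Array[Any]) extends SketchInternalRow {
    override def numFields: Int = values.length
    override def get(i: Int): Any = values(i)
  }

  // Concatenates two rows by offsetting indices past row1.numFields,
  // mirroring the JoinedRow family rewritten throughout this patch.
  class SketchJoinedRow(row1: SketchInternalRow, row2: SketchInternalRow)
      extends SketchInternalRow {
    override def numFields: Int = row1.numFields + row2.numFields
    override def get(i: Int): Any =
      if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields)
  }

With this contract, `new SketchJoinedRow(left, right).get(i)` reads transparently from either side; that index arithmetic is exactly what the JoinedRow classes below repeat for every typed getter.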
--- .../apache/spark/mllib/linalg/Matrices.scala | 4 +- .../apache/spark/mllib/linalg/Vectors.scala | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 9 +- .../sql/catalyst/CatalystTypeConverters.scala | 14 +- .../spark/sql/catalyst/InternalRow.scala | 153 ++++++++++++---- .../spark/sql/catalyst/expressions/Cast.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 168 +++++++++--------- .../expressions/SpecificMutableRow.scala | 4 +- .../sql/catalyst/expressions/aggregates.scala | 4 +- .../codegen/GenerateProjection.scala | 2 +- .../expressions/complexTypeExtractors.scala | 4 +- .../spark/sql/catalyst/expressions/rows.scala | 57 +++--- .../scala/org/apache/spark/sql/RowTest.scala | 10 -- .../sql/catalyst/expressions/CastSuite.scala | 24 ++- .../expressions/ComplexTypeSuite.scala | 7 +- .../spark/sql/columnar/ColumnType.scala | 2 +- .../columnar/InMemoryColumnarTableScan.scala | 12 +- .../sql/execution/SparkSqlSerializer2.scala | 10 +- .../datasources/DataSourceStrategy.scala | 4 +- .../sql/execution/datasources/commands.scala | 53 ++++-- .../spark/sql/execution/datasources/ddl.scala | 16 +- .../spark/sql/execution/pythonUDFs.scala | 4 +- .../sql/expressions/aggregate/udaf.scala | 3 +- .../apache/spark/sql/jdbc/JDBCRelation.scala | 3 +- .../apache/spark/sql/json/JSONRelation.scala | 6 +- .../sql/parquet/CatalystRowConverter.scala | 10 +- .../sql/parquet/ParquetTableOperations.scala | 2 +- .../sql/parquet/ParquetTableSupport.scala | 12 +- .../apache/spark/sql/parquet/newParquet.scala | 6 +- .../apache/spark/sql/sources/interfaces.scala | 22 ++- .../scala/org/apache/spark/sql/RowSuite.scala | 4 +- .../spark/sql/sources/DDLTestSuite.scala | 5 +- .../spark/sql/sources/PrunedScanSuite.scala | 2 +- .../spark/sql/sources/TableScanSuite.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 9 +- .../spark/sql/hive/hiveWriterContainers.scala | 8 +- .../spark/sql/hive/orc/OrcRelation.scala | 8 +- .../CommitFailureTestRelationSuite.scala | 47 +++++ .../ParquetHadoopFsRelationSuite.scala | 139 +++++++++++++++ .../SimpleTextHadoopFsRelationSuite.scala | 57 ++++++ .../sql/sources/hadoopFsRelationSuites.scala | 166 ----------------- 41 files changed, 647 insertions(+), 433 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 55da0e094d132..b6e2c30fbf104 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -174,8 +174,8 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { override def deserialize(datum: Any): Matrix = { datum match { case row: InternalRow => - require(row.length == 7, - s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7") + require(row.numFields == 7, + s"MatrixUDT.deserialize given row with length ${row.numFields} but requires length == 7") val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 9067b3ba9a7bb..c884aad08889f 
100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -203,8 +203,8 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { override def deserialize(datum: Any): Vector = { datum match { case row: InternalRow => - require(row.length == 4, - s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + require(row.numFields == 4, + s"VectorUDT.deserialize given row with length ${row.numFields} but requires length == 4") val tpe = row.getByte(0) tpe match { case 0 => diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fa1216b455a9e..a8986608855e2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -64,7 +64,8 @@ public final class UnsafeRow extends MutableRow { /** The size of this row's backing data, in bytes) */ private int sizeInBytes; - public int length() { return numFields; } + @Override + public int numFields() { return numFields; } /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; @@ -218,12 +219,12 @@ public void setFloat(int ordinal, float value) { } @Override - public int size() { - return numFields; + public Object get(int i) { + throw new UnsupportedOperationException(); } @Override - public Object get(int i) { + public T getAs(int i) { throw new UnsupportedOperationException(); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index bfaee04f33b7f..5c3072a77aeba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -140,14 +140,14 @@ object CatalystTypeConverters { private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = catalystValue - override def toScalaImpl(row: InternalRow, column: Int): Any = row(column) + override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column) } private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) - override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row(column)) + override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column)) } /** Converter for arrays, sequences, and Java iterables. 
*/ @@ -184,7 +184,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row(column).asInstanceOf[Seq[Any]]) + toScala(row.get(column).asInstanceOf[Seq[Any]]) } private case class MapConverter( @@ -227,7 +227,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = - toScala(row(column).asInstanceOf[Map[Any, Any]]) + toScala(row.get(column).asInstanceOf[Map[Any, Any]]) } private case class StructConverter( @@ -260,9 +260,9 @@ object CatalystTypeConverters { if (row == null) { null } else { - val ar = new Array[Any](row.size) + val ar = new Array[Any](row.numFields) var idx = 0 - while (idx < row.size) { + while (idx < row.numFields) { ar(idx) = converters(idx).toScala(row, idx) idx += 1 } @@ -271,7 +271,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Row = - toScala(row(column).asInstanceOf[InternalRow]) + toScala(row.get(column).asInstanceOf[InternalRow]) } private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index c7ec49b3d6c3d..efc4faea569b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -25,48 +25,139 @@ import org.apache.spark.unsafe.types.UTF8String * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -abstract class InternalRow extends Row { +abstract class InternalRow extends Serializable { - def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + def numFields: Int - def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + def get(i: Int): Any - // This is only use for test - override def getString(i: Int): String = getAs[UTF8String](i).toString - - // These expensive API should not be used internally. - final override def getDecimal(i: Int): java.math.BigDecimal = - throw new UnsupportedOperationException - final override def getDate(i: Int): java.sql.Date = - throw new UnsupportedOperationException - final override def getTimestamp(i: Int): java.sql.Timestamp = - throw new UnsupportedOperationException - final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException - final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException - final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = - throw new UnsupportedOperationException - final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = - throw new UnsupportedOperationException - final override def getStruct(i: Int): Row = throw new UnsupportedOperationException - final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException - final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = - throw new UnsupportedOperationException - - // A default implementation to change the return type - override def copy(): InternalRow = this + // TODO: Remove this. 
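+  // For now, `apply(i)` simply forwards to `get(i)`; new code should call `get` directly.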
+ def apply(i: Int): Any = get(i) + + def getAs[T](i: Int): T = get(i).asInstanceOf[T] + + def isNullAt(i: Int): Boolean = get(i) == null + + def getBoolean(i: Int): Boolean = getAs[Boolean](i) + + def getByte(i: Int): Byte = getAs[Byte](i) + + def getShort(i: Int): Short = getAs[Short](i) + + def getInt(i: Int): Int = getAs[Int](i) + + def getLong(i: Int): Long = getAs[Long](i) + + def getFloat(i: Int): Float = getAs[Float](i) + + def getDouble(i: Int): Double = getAs[Double](i) + + override def toString: String = s"[${this.mkString(",")}]" + + /** + * Make a copy of the current [[InternalRow]] object. + */ + def copy(): InternalRow = this + + /** Returns true if there are any NULL values in this row. */ + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[InternalRow]) { + return false + } + + val other = o.asInstanceOf[InternalRow] + if (other eq null) { + return false + } + + val len = numFields + if (len != other.numFields) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + /* ---------------------- utility methods for Scala ---------------------- */ /** - * Returns true if we can check equality for these 2 rows. - * Equality check between external row and internal row is not allowed. - * Here we do this check to prevent call `equals` on internal row with external row. + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. */ - protected override def canEqual(other: Row) = other.isInstanceOf[InternalRow] + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. + */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) + + def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + + // This is only use for test + def getString(i: Int): String = getAs[UTF8String](i).toString // Custom hashCode function that matches the efficient code generated version. 
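  // A null field contributes 0 to the running hash; each non-null field
  // folds the hash of its value into the running result.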
override def hashCode: Int = { var result: Int = 37 var i = 0 - while (i < length) { + val len = numFields + while (i < len) { val update: Int = if (isNullAt(i)) { 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c66854d52c50b..47ad3e089e4c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -382,8 +382,8 @@ case class Cast(child: Expression, dataType: DataType) val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 - while (i < row.length) { - newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row(i))) + while (i < row.numFields) { + newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row.get(i))) i += 1 } newRow.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 04872fbc8b091..dbda05a792cbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -176,49 +176,49 @@ class JoinedRow extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else 
row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -278,49 +278,49 @@ class JoinedRow2 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -374,50 +374,50 @@ class JoinedRow3 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - 
if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -471,50 +471,50 @@ class JoinedRow4 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else 
row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -568,50 +568,50 @@ class JoinedRow5 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if 
(i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) @@ -665,50 +665,50 @@ class JoinedRow6 extends InternalRow { override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - override def length: Int = row1.length + row2.length + override def numFields: Int = row1.numFields + row2.numFields override def getUTF8String(i: Int): UTF8String = { - if (i < row1.length) row1.getUTF8String(i) else row2.getUTF8String(i - row1.length) + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) } override def getBinary(i: Int): Array[Byte] = { - if (i < row1.length) row1.getBinary(i) else row2.getBinary(i - row1.length) + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } override def get(i: Int): Any = - if (i < row1.length) row1(i) else row2(i - row1.length) + if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = - if (i < row1.length) row1.isNullAt(i) else row2.isNullAt(i - row1.length) + if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) override def getInt(i: Int): Int = - if (i < row1.length) row1.getInt(i) else row2.getInt(i - row1.length) + if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = - if (i < row1.length) row1.getLong(i) else row2.getLong(i - row1.length) + if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) override def getDouble(i: Int): Double = - if (i < row1.length) row1.getDouble(i) else row2.getDouble(i - row1.length) + if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) override def getBoolean(i: Int): Boolean = - if (i < row1.length) row1.getBoolean(i) else row2.getBoolean(i - row1.length) + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) override def getShort(i: Int): Short = - if (i < row1.length) row1.getShort(i) else row2.getShort(i - row1.length) + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) override def getByte(i: Int): Byte = - if (i < row1.length) row1.getByte(i) else row2.getByte(i - row1.length) + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) override def getFloat(i: Int): Float = - if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) override def copy(): InternalRow = { - val totalSize = row1.length + row2.length + val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) var i = 0 while(i < totalSize) { - copiedValues(i) = apply(i) + copiedValues(i) = get(i) i += 1 } new GenericInternalRow(copiedValues) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 6f291d2c86c1e..4b4833bd06a3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -211,7 +211,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this() = this(Seq.empty) - override def length: Int = values.length + override def numFields: Int = values.length override def toSeq: Seq[Any] = values.map(_.boxed).toSeq @@ -245,7 +245,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String.fromString(value)) - override def getString(ordinal: Int): String = apply(ordinal).toString + override def getString(ordinal: Int): String = get(ordinal).toString override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 73fde4e9164d7..62b6cc834c9c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -675,7 +675,7 @@ case class CombineSetsAndSumFunction( val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] val inputIterator = inputSetEval.iterator while (inputIterator.hasNext) { - seen.add(inputIterator.next) + seen.add(inputIterator.next()) } } @@ -685,7 +685,7 @@ case class CombineSetsAndSumFunction( null } else { Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( + casted.iterator.map(f => f.get(0)).reduceLeft( base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), base.dataType).eval(null) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 405d6b0e3bc76..f0efc4bff12ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -178,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $initColumns } - public int length() { return ${expressions.length};} + public int numFields() { return ${expressions.length};} protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5504781edca1b..c91122cda2a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -110,7 +110,7 @@ case class 
GetStructField(child: Expression, field: StructField, ordinal: Int) override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow](ordinal) + input.asInstanceOf[InternalRow].get(ordinal) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { @@ -142,7 +142,7 @@ case class GetArrayStructFields( protected override def nullSafeEval(input: Any): Any = { input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row(ordinal) + if (row == null) null else row.get(ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index d78be5a5958f9..53779dd4049d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -44,9 +44,10 @@ abstract class MutableRow extends InternalRow { } override def copy(): InternalRow = { - val arr = new Array[Any](length) + val n = numFields + val arr = new Array[Any](n) var i = 0 - while (i < length) { + while (i < n) { arr(i) = get(i) i += 1 } @@ -54,36 +55,23 @@ abstract class MutableRow extends InternalRow { } } -/** - * A row implementation that uses an array of objects as the underlying storage. - */ -trait ArrayBackedRow { - self: Row => - - protected val values: Array[Any] - - override def toSeq: Seq[Any] = values.toSeq - - def length: Int = values.length - - override def get(i: Int): Any = values(i) - - def setNullAt(i: Int): Unit = { values(i) = null} - - def update(i: Int, value: Any): Unit = { values(i) = value } -} - /** * A row implementation that uses an array of objects as the underlying storage. Note that, while * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ -class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def length: Int = values.length + + override def get(i: Int): Any = values(i) + + override def toSeq: Seq[Any] = values.toSeq + override def copy(): Row = this } @@ -101,34 +89,49 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) * Note that, while the array is not copied, and thus could technically be mutated after creation, * this is not allowed. */ -class GenericInternalRow(protected[sql] val values: Array[Any]) - extends InternalRow with ArrayBackedRow { +class GenericInternalRow(protected[sql] val values: Array[Any]) extends InternalRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int): Any = values(i) + override def copy(): InternalRow = this } /** * This is used for serialization of Python DataFrame */ -class GenericInternalRowWithSchema(values: Array[Any], override val schema: StructType) +class GenericInternalRowWithSchema(values: Array[Any], val schema: StructType) extends GenericInternalRow(values) { /** No-arg constructor for serialization. 
*/ protected def this() = this(null, null) - override def fieldIndex(name: String): Int = schema.fieldIndex(name) + def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { +class GenericMutableRow(val values: Array[Any]) extends MutableRow { /** No-arg constructor for serialization. */ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) + override def toSeq: Seq[Any] = values.toSeq + + override def numFields: Int = values.length + + override def get(i: Int): Any = values(i) + + override def setNullAt(i: Int): Unit = { values(i) = null} + + override def update(i: Int, value: Any): Unit = { values(i) = value } + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 878a1bb9b7e6d..01ff84cb56054 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -83,15 +83,5 @@ class RowTest extends FunSpec with Matchers { it("equality check for internal rows") { internalRow shouldEqual internalRow2 } - - it("throws an exception when check equality between external and internal rows") { - def assertError(f: => Unit): Unit = { - val e = intercept[UnsupportedOperationException](f) - e.getMessage.contains("cannot check equality between external and internal rows") - } - - assertError(internalRow.equals(externalRow)) - assertError(externalRow.equals(internalRow)) - } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index facf65c155148..408353cf70a49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Test suite for data type casting expression [[Cast]]. 
@@ -580,14 +581,21 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from struct") { val struct = Literal.create( - InternalRow("123", "abc", "", null), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString(""), + null), StructType(Seq( StructField("a", StringType, nullable = true), StructField("b", StringType, nullable = true), StructField("c", StringType, nullable = true), StructField("d", StringType, nullable = true)))) val struct_notNull = Literal.create( - InternalRow("123", "abc", ""), + InternalRow( + UTF8String.fromString("123"), + UTF8String.fromString("abc"), + UTF8String.fromString("")), StructType(Seq( StructField("a", StringType, nullable = false), StructField("b", StringType, nullable = false), @@ -676,8 +684,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( InternalRow( - Seq("123", "abc", ""), - Map("a" -> "123", "b" -> "abc", "c" -> ""), + Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")), + Map( + UTF8String.fromString("a") -> UTF8String.fromString("123"), + UTF8String.fromString("b") -> UTF8String.fromString("abc"), + UTF8String.fromString("c") -> UTF8String.fromString("")), InternalRow(0)), StructType(Seq( StructField("a", @@ -700,7 +711,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved === true) checkEvaluation(ret, InternalRow( Seq(123, null, null), - Map("a" -> true, "b" -> true, "c" -> false), + Map( + UTF8String.fromString("a") -> true, + UTF8String.fromString("b") -> true, + UTF8String.fromString("c") -> false), InternalRow(0L))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index a8aee8f634e03..fc842772f3480 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -150,12 +151,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateNamedStruct with literal field") { val row = InternalRow(1, 2, 3) val c1 = 'a.int.at(0) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), InternalRow(1, "y"), row) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), + InternalRow(1, UTF8String.fromString("y")), row) } test("CreateNamedStruct from all literal fields") { checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), InternalRow("x", 2.0), InternalRow.empty) + CreateNamedStruct(Seq("a", "x", "b", 2.0)), + InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty) } test("test dsl for complex type") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 9d8415f06399c..ac42bde07c37d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -309,7 +309,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { override def actualSize(row: InternalRow, ordinal: Int): Int = { - row.getString(ordinal).getBytes("utf-8").length + 4 + row.getUTF8String(ordinal).numBytes() + 4 } override def append(v: UTF8String, buffer: ByteBuffer): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 38720968c1313..5d5b0697d7016 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -134,13 +134,13 @@ private[sql] case class InMemoryRelation( // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat // hard to decipher. assert( - row.size == columnBuilders.size, - s"""Row column number mismatch, expected ${output.size} columns, but got ${row.size}. - |Row content: $row - """.stripMargin) + row.numFields == columnBuilders.size, + s"Row column number mismatch, expected ${output.size} columns, " + + s"but got ${row.numFields}." + + s"\nRow content: $row") var i = 0 - while (i < row.length) { + while (i < row.numFields) { columnBuilders(i).appendFrom(row, i) i += 1 } @@ -304,7 +304,7 @@ private[sql] case class InMemoryColumnarTableScan( // Extract rows via column accessors new Iterator[InternalRow] { - private[this] val rowLen = nextRow.length + private[this] val rowLen = nextRow.numFields override def next(): InternalRow = { var i = 0 while (i < rowLen) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c87e2064a8f33..83c4e8733f15f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -25,7 +25,6 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.serializer._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ @@ -53,7 +52,7 @@ private[sql] class Serializer2SerializationStream( private val writeRowFunc = SparkSqlSerializer2.createSerializationFunction(rowSchema, rowOut) override def writeObject[T: ClassTag](t: T): SerializationStream = { - val kv = t.asInstanceOf[Product2[Row, Row]] + val kv = t.asInstanceOf[Product2[InternalRow, InternalRow]] writeKey(kv._1) writeValue(kv._2) @@ -66,7 +65,7 @@ private[sql] class Serializer2SerializationStream( } override def writeValue[T: ClassTag](t: T): SerializationStream = { - writeRowFunc(t.asInstanceOf[Row]) + writeRowFunc(t.asInstanceOf[InternalRow]) this } @@ -205,8 +204,9 @@ private[sql] object SparkSqlSerializer2 { /** * The util function to create the serialization function based on the given schema. 
*/ - def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = { - (row: Row) => + def createSerializationFunction(schema: Array[DataType], out: DataOutputStream) + : InternalRow => Unit = { + (row: InternalRow) => // If the schema is null, the returned function does nothing when it get called. if (schema != null) { var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2b400926177fe..7f452daef33c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -206,7 +206,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val mutableRow = new SpecificMutableRow(dataTypes) iterator.map { dataRow => var i = 0 - while (i < mutableRow.length) { + while (i < mutableRow.numFields) { mergers(i)(mutableRow, dataRow, i) i += 1 } @@ -315,7 +315,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (relation.relation.needConversion) { execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala index cd2aa7f7433c5..d551f386eee6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/commands.scala @@ -174,14 +174,19 @@ private[sql] case class InsertIntoHadoopFsRelation( try { writerContainer.executorSideSetup(taskContext) - val converter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } - while (iterator.hasNext) { - val internalRow = iterator.next() - writerContainer.outputWriterForRow(internalRow).write(converter(internalRow)) + while (iterator.hasNext) { + val internalRow = iterator.next() + writerContainer.outputWriterForRow(internalRow) + .asInstanceOf[OutputWriterInternal].writeInternal(internalRow) + } } writerContainer.commitTask() @@ -248,17 +253,23 @@ private[sql] case class InsertIntoHadoopFsRelation( val partitionProj = newProjection(codegenEnabled, partitionCasts, output) val dataProj = newProjection(codegenEnabled, dataOutput, output) - val dataConverter: InternalRow => Row = if (needsConversion) { - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + if (needsConversion) { + val converter = CatalystTypeConverters.createToScalaConverter(dataSchema) + .asInstanceOf[InternalRow => Row] + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = converter(dataProj(internalRow)) + writerContainer.outputWriterForRow(partitionPart).write(dataPart) + } } else { - r: InternalRow => r.asInstanceOf[Row] - } 
- - while (iterator.hasNext) { - val internalRow = iterator.next() - val partitionPart = partitionProj(internalRow) - val dataPart = dataConverter(dataProj(internalRow)) - writerContainer.outputWriterForRow(partitionPart).write(dataPart) + while (iterator.hasNext) { + val internalRow = iterator.next() + val partitionPart = partitionProj(internalRow) + val dataPart = dataProj(internalRow) + writerContainer.outputWriterForRow(partitionPart) + .asInstanceOf[OutputWriterInternal].writeInternal(dataPart) + } } writerContainer.commitTask() @@ -530,8 +541,12 @@ private[sql] class DynamicPartitionWriterContainer( while (i < partitionColumns.length) { val col = partitionColumns(i) val partitionValueString = { - val string = row.getString(i) - if (string.eq(null)) defaultPartitionName else PartitioningUtils.escapePathName(string) + val string = row.getUTF8String(i) + if (string.eq(null)) { + defaultPartitionName + } else { + PartitioningUtils.escapePathName(string.toString) + } } if (i > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index c8033d3c0470a..1f2797ec5527a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -23,11 +23,11 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext, SaveMode} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} +import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, InternalRow} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -415,12 +415,12 @@ private[sql] case class CreateTempTableUsing( provider: String, options: Map[String, String]) extends RunnableCommand { - def run(sqlContext: SQLContext): Seq[InternalRow] = { + def run(sqlContext: SQLContext): Seq[Row] = { val resolved = ResolvedDataSource( sqlContext, userSpecifiedSchema, Array.empty[String], provider, options) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } @@ -432,20 +432,20 @@ private[sql] case class CreateTempTableUsingAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) sqlContext.registerDataFrameAsTable( DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) - Seq.empty + Seq.empty[Row] } } private[sql] case class RefreshTable(databaseName: String, tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. 
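// (Editor's note, not part of the patch: every run() override in this file now
// returns Seq[Row] instead of Seq[InternalRow]. Command results are handed
// straight back to callers through executeCollect(), so they must be external
// Rows; only the physical operators underneath traffic in InternalRow.)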
sqlContext.catalog.refreshTable(databaseName, tableName) @@ -464,7 +464,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) sqlContext.cacheManager.cacheQuery(df, Some(tableName)) } - Seq.empty[InternalRow] + Seq.empty[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index e6e27a87c7151..40bf03a3f1a62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -126,9 +126,9 @@ object EvaluatePython { case (null, _) => null case (row: InternalRow, struct: StructType) => - val values = new Array[Any](row.size) + val values = new Array[Any](row.numFields) var i = 0 - while (i < row.size) { + while (i < row.numFields) { values(i) = toJava(row(i), struct.fields(i).dataType) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala index 6c49a906c848a..46f0fac861282 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -148,7 +148,7 @@ class InputAggregationBuffer private[sql] ( toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, - var underlyingInputBuffer: Row) + var underlyingInputBuffer: InternalRow) extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { override def get(i: Int): Any = { @@ -156,6 +156,7 @@ class InputAggregationBuffer private[sql] ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } + // TODO: Use buffer schema to avoid using generic getter. 
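// (Editor's sketch of the TODO above; `bufferSchema` is hypothetical and not a
// field of this class yet. With a schema in hand, a typed getter could replace
// the boxing generic apply():
//   bufferSchema(i).dataType match {
//     case IntegerType => underlyingInputBuffer.getInt(offsets(i))
//     case LongType    => underlyingInputBuffer.getLong(offsets(i))
//     case _           => underlyingInputBuffer(offsets(i)) // generic fallback
//   }
// )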
toScalaConverters(i)(underlyingInputBuffer(offsets(i))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala index 4d3aac464c538..41d0ecb4bbfbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala @@ -128,6 +128,7 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { val driver: String = DriverRegistry.getDriverClassName(url) + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, @@ -137,7 +138,7 @@ private[sql] case class JDBCRelation( table, requiredColumns, filters, - parts).map(_.asInstanceOf[Row]) + parts).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala index 922794ac9aac5..562b058414d07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala @@ -154,17 +154,19 @@ private[sql] class JSONRelation( } override def buildScan(): RDD[Row] = { + // Rely on type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), schema, - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JacksonParser( baseRDD(), StructType.fromAttributes(requiredColumns), - sqlContext.conf.columnNameOfCorruptRecord).map(_.asInstanceOf[Row]) + sqlContext.conf.columnNameOfCorruptRecord).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index 0c3d8fdab6bd2..b5e4263008f56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -28,7 +28,7 @@ import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveCo import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema.{GroupType, PrimitiveType, Type} -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -55,8 +55,8 @@ private[parquet] trait ParentContainerUpdater { private[parquet] object NoopUpdater extends ParentContainerUpdater /** - * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[Row]]s. Since - * any Parquet record is also a struct, this converter can also be used as root converter. + * A [[CatalystRowConverter]] is used to convert Parquet "structs" into Spark SQL [[InternalRow]]s. + * Since any Parquet record is also a struct, this converter can also be used as root converter. 
* * When used as a root converter, [[NoopUpdater]] should be used since root converters don't have * any "parent" container. @@ -108,7 +108,7 @@ private[parquet] class CatalystRowConverter( override def start(): Unit = { var i = 0 - while (i < currentRow.length) { + while (i < currentRow.numFields) { currentRow.setNullAt(i) i += 1 } @@ -178,7 +178,7 @@ private[parquet] class CatalystRowConverter( case t: StructType => new CatalystRowConverter(parquetType.asGroupType(), t, new ParentContainerUpdater { - override def set(value: Any): Unit = updater.set(value.asInstanceOf[Row].copy()) + override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy()) }) case t: UserDefinedType[_] => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 28cba5e54d69e..8cab27d6e1c46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -178,7 +178,7 @@ private[sql] case class ParquetTableScan( val row = iter.next()._2.asInstanceOf[InternalRow] var i = 0 - while (i < row.size) { + while (i < row.numFields) { mutableRow(i) = row(i) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index d1040bf5562a2..c7c58e69d42ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -208,9 +208,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 @@ -378,9 +378,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo private[parquet] class MutableRowWriteSupport extends RowWriteSupport { override def write(record: InternalRow): Unit = { val attributesSize = attributes.size - if (attributesSize > record.size) { - throw new IndexOutOfBoundsException( - s"Trying to write more fields than contained in row ($attributesSize > ${record.size})") + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") } var index = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index c384697c0ee62..8ec228c2b25bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -61,7 +61,7 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider { // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
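// (Editor's illustration, not part of the patch: the writer below moves to the
// OutputWriterInternal contract added later in this commit, which keeps
// InternalRow on the hot path and makes the external entry point unusable.
// A minimal self-contained sketch of that contract:)
abstract class OutputWriterInternalSketch {
  // external entry point deliberately rejected for internal writers
  def write(row: org.apache.spark.sql.Row): Unit =
    throw new UnsupportedOperationException("internal writers take InternalRow")
  // hot path: no per-row asInstanceOf needed anymore
  def writeInternal(row: org.apache.spark.sql.catalyst.InternalRow): Unit
}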
private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriter { + extends OutputWriterInternal { private val recordWriter: RecordWriter[Void, InternalRow] = { val outputFormat = { @@ -86,7 +86,7 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext outputFormat.getRecordWriter(context) } - override def write(row: Row): Unit = recordWriter.write(null, row.asInstanceOf[InternalRow]) + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) } @@ -324,7 +324,7 @@ private[sql] class ParquetRelation2( new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) } } - }.values.map(_.asInstanceOf[Row]) + }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7cd005b959488..119bac786d478 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -344,6 +344,18 @@ abstract class OutputWriter { def close(): Unit } +/** + * This is an internal, private version of [[OutputWriter]] with a writeInternal method that + * accepts an [[InternalRow]] rather than a [[Row]]. Data sources that return this must have + * the conversion flag set to false. + */ +private[sql] abstract class OutputWriterInternal extends OutputWriter { + + override def write(row: Row): Unit = throw new UnsupportedOperationException + + def writeInternal(row: InternalRow): Unit +} + /** * ::Experimental:: * A [[BaseRelation]] that provides much of the common code required for formats that store their @@ -592,12 +604,12 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio BoundReference(dataSchema.fieldIndex(col), field.dataType, field.nullable) }.toSeq - val rdd = buildScan(inputFiles) - val converted = + val rdd: RDD[Row] = buildScan(inputFiles) + val converted: RDD[InternalRow] = if (needConversion) { RDDConversions.rowToRowRdd(rdd, dataSchema.fields.map(_.dataType)) } else { - rdd.map(_.asInstanceOf[InternalRow]) + rdd.asInstanceOf[RDD[InternalRow]] } converted.mapPartitions { rows => val buildProjection = if (codegenEnabled) { @@ -606,8 +618,8 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) } val mutableProjection = buildProjection() - rows.map(r => mutableProjection(r).asInstanceOf[Row]) - } + rows.map(r => mutableProjection(r)) + }.asInstanceOf[RDD[Row]] } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 7cc6ffd7548d0..0e5c5abff85f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -35,14 +35,14 @@ class RowSuite extends SparkFunSuite { expected.update(2, false) expected.update(3, null) val actual1 = Row(2147483647, "this is a string", false, null) - assert(expected.size === actual1.size) + assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) assert(expected(3) ===
actual1(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) - assert(expected.size === actual2.size) + assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index da53ec16b5c41..84855ce45e918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -61,9 +61,10 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def needConversion: Boolean = false override def buildScan(): RDD[Row] = { + // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] sqlContext.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row - } + InternalRow(UTF8String.fromString(s"people$e"), e * 2) + }.asInstanceOf[RDD[Row]] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 257526feab945..0d5183444af78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -131,7 +131,7 @@ class PrunedScanSuite extends DataSourceTest { queryExecution) } - if (rawOutput.size != expectedColumns.size) { + if (rawOutput.numFields != expectedColumns.size) { fail(s"Wrong output row. Got $rawOutput\n$queryExecution") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 143aadc08b1c4..5e189c3563ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -93,7 +93,7 @@ case class AllDataTypesScan( InternalRow(i, UTF8String.fromString(i.toString)), InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) - } + }.asInstanceOf[RDD[Row]] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 8202e553afbfe..34b629403e128 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -122,7 +122,7 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { + protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. val tableDesc = table.tableDesc @@ -252,13 +252,12 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. 
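// (Editor's aside, not part of the patch: the parallelize call below casts
// Seq[Row] back to Seq[InternalRow]. Like the "type erasure hack" comments
// added elsewhere in this series, the cast compiles to a no-op because generic
// type parameters are erased at runtime, and it is safe here only because the
// sequence is always empty. The same trick converts a whole RDD for free,
// instead of paying for a map(_.asInstanceOf[...]) pass over every row:)
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow

def eraseToRowRDD(internal: RDD[InternalRow]): RDD[Row] =
  internal.asInstanceOf[RDD[Row]] // O(1): no element is inspected or copied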
// TODO: implement hive compatibility as rules. - Seq.empty[InternalRow] + Seq.empty[Row] } - override def executeCollect(): Array[Row] = - sideEffectResult.toArray + override def executeCollect(): Array[Row] = sideEffectResult.toArray protected override def doExecute(): RDD[InternalRow] = { - sqlContext.sparkContext.parallelize(sideEffectResult, 1) + sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ecc78a5f8d321..8850e060d2a73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.hive.common.FileUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.Row import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types._ @@ -94,7 +95,9 @@ private[hive] class SparkHiveWriterContainer( "part-" + numberFormat.format(splitID) + extension } - def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = writer + def getLocalFileWriter(row: InternalRow, schema: StructType): FileSinkOperator.RecordWriter = { + writer + } def close() { // Seems the boolean value passed into close does not matter. @@ -197,7 +200,8 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( jobConf.setBoolean(SUCCESSFUL_JOB_OUTPUT_DIR_MARKER, oldMarker) } - override def getLocalFileWriter(row: Row, schema: StructType): FileSinkOperator.RecordWriter = { + override def getLocalFileWriter(row: InternalRow, schema: StructType) + : FileSinkOperator.RecordWriter = { def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index de63ee56dd8e6..10623dc820316 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -66,7 +66,7 @@ private[orc] class OrcOutputWriter( path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriterInternal with SparkHadoopMapRedUtil with HiveInspectors { private val serializer = { val table = new Properties() @@ -119,9 +119,9 @@ private[orc] class OrcOutputWriter( ).asInstanceOf[RecordWriter[NullWritable, Writable]] } - override def write(row: Row): Unit = { + override def writeInternal(row: InternalRow): Unit = { var i = 0 - while (i < row.length) { + while (i < row.numFields) { reusableOutputBuffer(i) = wrappers(i)(row(i)) i += 1 } @@ -192,7 +192,7 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) + OrcTableScan(output, this, filters, inputPaths).execute().asInstanceOf[RDD[Row]] } override def 
prepareJobForWrite(job: Job): OutputWriterFactory = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala new file mode 100644 index 0000000000000..e976125b3706d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils + + +class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { + override val sqlContext = TestHive + + // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. + val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName + + test("SPARK-7684: commitTask() failure should fallback to abortTask()") { + withTempPath { file => + // Here we coalesce partition number to 1 to ensure that only a single task is issued. This + // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` + // directory while committing/aborting the job. See SPARK-8513 for more details. + val df = sqlContext.range(0, 10).coalesce(1) + intercept[SparkException] { + df.write.format(dataSourceName).save(file.getCanonicalPath) + } + + val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) + assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..d280543a071d9 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/ParquetHadoopFsRelationSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import java.io.File + +import com.google.common.io.Files +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{AnalysisException, SaveMode, parquet} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + + +class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName + + import sqlContext._ + import sqlContext.implicits._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) + .toDF("a", "b", "p1") + .write.parquet(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } + + test("SPARK-7868: _temporary directories should be ignored") { + withTempPath { dir => + val df = Seq("a", "b", "c").zipWithIndex.toDF() + + df.write + .format("parquet") + .save(dir.getCanonicalPath) + + df.write + .format("parquet") + .save(s"${dir.getCanonicalPath}/_temporary") + + checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) + } + } + + test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { + withTempDir { dir => + val path = dir.getCanonicalPath + val df = Seq(1 -> "a").toDF() + + // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw + // since it's not a valid Parquet file. + val emptyFile = new File(path, "empty") + Files.createParentDirs(emptyFile) + Files.touch(emptyFile) + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Ignore).save(path) + + // This should only complain that the destination directory already exists, rather than file + // "empty" is not a Parquet file. + assert { + intercept[AnalysisException] { + df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) + }.getMessage.contains("already exists") + } + + // This shouldn't throw anything. + df.write.format("parquet").mode(SaveMode.Overwrite).save(path) + checkAnswer(read.format("parquet").load(path), df) + } + } + + test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { + withTempPath { dir => + intercept[AnalysisException] { + // Parquet doesn't allow field names with spaces. Here we are intentionally making an + // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger + // the bug. Please refer to spark-8079 for more details. 
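// (Editor's note: Parquet rejects field names containing spaces and certain
// other special characters, so renaming "id" to "a b" below makes the write
// fail while the job is still being prepared, which is exactly the abortJob()
// path that SPARK-8079 needs to exercise.)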
+ range(1, 10) + .withColumnRenamed("id", "a b") + .write + .format("parquet") + .save(dir.getCanonicalPath) + } + } + } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala new file mode 100644 index 0000000000000..d761909d60e21 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources + +import org.apache.hadoop.fs.Path + +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/* +This is commented out due to a bug in the data source API (SPARK-9291).
+ + +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { + override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName + + import sqlContext._ + + test("save()/load() - partitioned table - simple queries - partition columns in data") { + withTempDir { file => + val basePath = new Path(file.getCanonicalPath) + val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) + val qualifiedBasePath = fs.makeQualified(basePath) + + for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { + val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") + sparkContext + .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") + .saveAsTextFile(partitionDir.toString) + } + + val dataSchemaWithPartition = + StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) + + checkQueries( + read.format(dataSourceName) + .option("dataSchema", dataSchemaWithPartition.json) + .load(file.getCanonicalPath)) + } + } +} +*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 2a8748d913569..dd274023a1cf5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -17,18 +17,14 @@ package org.apache.spark.sql.sources -import java.io.File - import scala.collection.JavaConversions._ -import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.hadoop.ParquetOutputCommitter -import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -581,165 +577,3 @@ class AlwaysFailParquetOutputCommitter( sys.error("Intentional job commitment failure for testing purpose.") } } - -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName - - import sqlContext._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield s"$i,val_$i,$p1") - .saveAsTextFile(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } -} - -class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils { - override val sqlContext = TestHive - - // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. 
- val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName - - test("SPARK-7684: commitTask() failure should fallback to abortTask()") { - withTempPath { file => - // Here we coalesce partition number to 1 to ensure that only a single task is issued. This - // prevents race condition happened when FileOutputCommitter tries to remove the `_temporary` - // directory while committing/aborting the job. See SPARK-8513 for more details. - val df = sqlContext.range(0, 10).coalesce(1) - intercept[SparkException] { - df.write.format(dataSourceName).save(file.getCanonicalPath) - } - - val fs = new Path(file.getCanonicalPath).getFileSystem(SparkHadoopUtil.get.conf) - assert(!fs.exists(new Path(file.getCanonicalPath, "_temporary"))) - } - } -} - -class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { - override val dataSourceName: String = classOf[parquet.DefaultSource].getCanonicalName - - import sqlContext._ - import sqlContext.implicits._ - - test("save()/load() - partitioned table - simple queries - partition columns in data") { - withTempDir { file => - val basePath = new Path(file.getCanonicalPath) - val fs = basePath.getFileSystem(SparkHadoopUtil.get.conf) - val qualifiedBasePath = fs.makeQualified(basePath) - - for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) { - val partitionDir = new Path(qualifiedBasePath, s"p1=$p1/p2=$p2") - sparkContext - .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) - .toDF("a", "b", "p1") - .write.parquet(partitionDir.toString) - } - - val dataSchemaWithPartition = - StructType(dataSchema.fields :+ StructField("p1", IntegerType, nullable = true)) - - checkQueries( - read.format(dataSourceName) - .option("dataSchema", dataSchemaWithPartition.json) - .load(file.getCanonicalPath)) - } - } - - test("SPARK-7868: _temporary directories should be ignored") { - withTempPath { dir => - val df = Seq("a", "b", "c").zipWithIndex.toDF() - - df.write - .format("parquet") - .save(dir.getCanonicalPath) - - df.write - .format("parquet") - .save(s"${dir.getCanonicalPath}/_temporary") - - checkAnswer(read.format("parquet").load(dir.getCanonicalPath), df.collect()) - } - } - - test("SPARK-8014: Avoid scanning output directory when SaveMode isn't SaveMode.Append") { - withTempDir { dir => - val path = dir.getCanonicalPath - val df = Seq(1 -> "a").toDF() - - // Creates an arbitrary file. If this directory gets scanned, ParquetRelation2 will throw - // since it's not a valid Parquet file. - val emptyFile = new File(path, "empty") - Files.createParentDirs(emptyFile) - Files.touch(emptyFile) - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Ignore).save(path) - - // This should only complain that the destination directory already exists, rather than file - // "empty" is not a Parquet file. - assert { - intercept[AnalysisException] { - df.write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - }.getMessage.contains("already exists") - } - - // This shouldn't throw anything. - df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - checkAnswer(read.format("parquet").load(path), df) - } - } - - test("SPARK-8079: Avoid NPE thrown from BaseWriterContainer.abortJob") { - withTempPath { dir => - intercept[AnalysisException] { - // Parquet doesn't allow field names with spaces. Here we are intentionally making an - // exception thrown from the `ParquetRelation2.prepareForWriteJob()` method to trigger - // the bug. Please refer to spark-8079 for more details. 
- range(1, 10) - .withColumnRenamed("id", "a b") - .write - .format("parquet") - .save(dir.getCanonicalPath) - } - } - } - - test("SPARK-8604: Parquet data source should write summary file while doing appending") { - withTempPath { dir => - val path = dir.getCanonicalPath - val df = sqlContext.range(0, 5) - df.write.mode(SaveMode.Overwrite).parquet(path) - - val summaryPath = new Path(path, "_metadata") - val commonSummaryPath = new Path(path, "_common_metadata") - - val fs = summaryPath.getFileSystem(configuration) - fs.delete(summaryPath, true) - fs.delete(commonSummaryPath, true) - - df.write.mode(SaveMode.Append).parquet(path) - checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) - - assert(fs.exists(summaryPath)) - assert(fs.exists(commonSummaryPath)) - } - } -} From c8d71a4183dfc83ff257047857af0b6d66c6b90d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Jul 2015 09:38:13 -0700 Subject: [PATCH 048/219] [SPARK-9305] Rename org.apache.spark.Row to Item. It's a thing used in test cases, but named Row. Pretty annoying because everytime I search for Row, it shows up before the Spark SQL Row, which is what a developer wants most of the time. Author: Reynold Xin Closes #7638 from rxin/remove-row and squashes the following commits: aeda52d [Reynold Xin] [SPARK-9305] Rename org.apache.spark.Row to Item. --- .../scala/org/apache/spark/PartitioningSuite.scala | 10 +++++----- .../org/apache/spark/sql/RandomDataGenerator.scala | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 3316f561a4949..aa8028792cb41 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -91,13 +91,13 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva test("RangePartitioner for keys that are not Comparable (but with Ordering)") { // Row does not extend Comparable, but has an implicit Ordering defined. - implicit object RowOrdering extends Ordering[Row] { - override def compare(x: Row, y: Row): Int = x.value - y.value + implicit object RowOrdering extends Ordering[Item] { + override def compare(x: Item, y: Item): Int = x.value - y.value } - val rdd = sc.parallelize(1 to 4500).map(x => (Row(x), Row(x))) + val rdd = sc.parallelize(1 to 4500).map(x => (Item(x), Item(x))) val partitioner = new RangePartitioner(1500, rdd) - partitioner.getPartition(Row(100)) + partitioner.getPartition(Item(100)) } test("RangPartitioner.sketch") { @@ -252,4 +252,4 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva } -private sealed case class Row(value: Int) +private sealed case class Item(value: Int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index b9f2ad7ec0481..75ae29d690770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -69,8 +69,7 @@ object RandomDataGenerator { * Returns a function which generates random values for the given [[DataType]], or `None` if no * random data generator is defined for that data type. 
The generated values will use an external * representation of the data type; for example, the random generator for [[DateType]] will return - * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a - * [[org.apache.spark.Row]]. + * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a [[Row]]. * * @param dataType the type to generate values for * @param nullable whether null values should be generated From c2b50d693e469558e3b3c9cbb9d76089d259b587 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Jul 2015 09:49:50 -0700 Subject: [PATCH 049/219] [SPARK-9292] Analysis should check that join conditions' data types are BooleanType This patch adds an analysis check to ensure that join conditions' data types are BooleanType. This check is necessary in order to report proper errors for non-boolean DataFrame join conditions. Author: Josh Rosen Closes #7630 from JoshRosen/SPARK-9292 and squashes the following commits: aec6c7b [Josh Rosen] Check condition type in resolved() 75a3ea6 [Josh Rosen] Fix SPARK-9292. --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 +++++ .../spark/sql/catalyst/plans/logical/basicOperators.scala | 5 ++++- .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 5 +++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c203fcecf20fb..c23ab3c74338d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -83,6 +83,11 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + failAnalysis( + s"join condition '${condition.prettyString}' " + + s"of type ${condition.dataType.simpleString} is not a boolean.") + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 6aefa9f67556a..57a12820fa4c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -128,7 +128,10 @@ case class Join( // Joins are only resolved if they don't introduce ambiguous expression ids. 
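// (Editor's note: `condition` is an Option[Expression], and Option.forall is
// vacuously true on None, so the new clause below only vetoes joins whose
// condition is present and non-boolean; unconditioned joins resolve as before:
//   assert(Option.empty[Int].forall(_ > 0)) // None: vacuously true
//   assert(Some(1).forall(_ > 0))           // Some: predicate is checked
//   assert(!Some(-1).forall(_ > 0))
// )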
override lazy val resolved: Boolean = { - childrenResolved && expressions.forall(_.resolved) && selfJoinResolved + childrenResolved && + expressions.forall(_.resolved) && + selfJoinResolved && + condition.forall(_.dataType == BooleanType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dca8c881f21ab..7bf678ebf71ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -118,6 +118,11 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { testRelation.where(Literal(1)), "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + errorTest( + "non-boolean join conditions", + testRelation.join(testRelation, condition = Some(Literal(1))), + "condition" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + errorTest( "missing group by", testRelation2.groupBy('a)('b), From e25312451322969ad716dddf8248b8c17f68323b Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 24 Jul 2015 10:56:48 -0700 Subject: [PATCH 050/219] [SPARK-9222] [MLlib] Make class instantiation variables in DistributedLDAModel private[clustering] This makes it easier to test all the class variables of the DistributedLDAmodel. Author: MechCoder Closes #7573 from MechCoder/lda_test and squashes the following commits: 2f1a293 [MechCoder] [SPARK-9222] [MLlib] Make class instantiation variables in DistributedLDAModel private[clustering] --- .../apache/spark/mllib/clustering/LDAModel.scala | 8 ++++---- .../apache/spark/mllib/clustering/LDASuite.scala | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 920b57756b625..31c1d520fd659 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -283,12 +283,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] { */ @Experimental class DistributedLDAModel private ( - private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], - private val globalTopicTotals: LDA.TopicCounts, + private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private[clustering] val globalTopicTotals: LDA.TopicCounts, val k: Int, val vocabSize: Int, - private val docConcentration: Double, - private val topicConcentration: Double, + private[clustering] val docConcentration: Double, + private[clustering] val topicConcentration: Double, private[spark] val iterationTimes: Array[Double]) extends LDAModel { import LDA._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index da70d9bd7c790..376a87f0511b4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite +import org.apache.spark.graphx.Edge import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext 
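// (Editor's note: the fields the new assertions below read — graph,
// globalTopicTotals, docConcentration, topicConcentration — are now
// private[clustering], i.e. visible anywhere under
// org.apache.spark.mllib.clustering. This suite lives in that package, so it
// can inspect the model's internals while user code still cannot.)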
import org.apache.spark.mllib.util.TestingUtils._ @@ -318,6 +319,20 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(distributedModel.k === sameDistributedModel.k) assert(distributedModel.vocabSize === sameDistributedModel.vocabSize) assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) + assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) + assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) + + val graph = distributedModel.graph + val sameGraph = sameDistributedModel.graph + assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect()) + val edge = graph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + val sameEdge = sameGraph.edges.map { + case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos) + }.sortBy(x => (x._1, x._2)).collect() + assert(edge === sameEdge) } finally { Utils.deleteRecursively(tempDir1) Utils.deleteRecursively(tempDir2) From 6aceaf3d62ee335570ddc07ccaf07e8c3776f517 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Jul 2015 11:34:23 -0700 Subject: [PATCH 051/219] [SPARK-9295] Analysis should detect sorting on unsupported column types This patch extends CheckAnalysis to throw errors for queries that try to sort on unsupported column types, such as ArrayType. Author: Josh Rosen Closes #7633 from JoshRosen/SPARK-9295 and squashes the following commits: 23b2fbf [Josh Rosen] Embed function in foreach bfe1451 [Josh Rosen] Update to allow sorting by null literals 2f1b802 [Josh Rosen] Add analysis rule to detect sorting on unsupported column types (SPARK-9295) --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 10 ++++++++++ .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c23ab3c74338d..81d473c1130f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -103,6 +103,16 @@ trait CheckAnalysis { aggregateExprs.foreach(checkValidAggregateExpression) + case Sort(orders, _, _) => + orders.foreach { order => + order.dataType match { + case t: AtomicType => // OK + case NullType => // OK + case t => + failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}") + } + } + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7bf678ebf71ce..2588df98246dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -113,6 +113,11 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { testRelation.select(Literal(1).cast(BinaryType).as('badCast)), "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + errorTest( + "sorting by unsupported column types", + 
listRelation.orderBy('list.asc), + "sorting" :: "type" :: "array" :: Nil) + errorTest( "non-boolean filters", testRelation.where(Literal(1)), From 8399ba14873854ab2f80a0ccaf6adba499060365 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 24 Jul 2015 11:53:16 -0700 Subject: [PATCH 052/219] [SPARK-9261] [STREAMING] Avoid calling APIs that expose shaded classes. Doing this may cause weird errors when tests are run on maven, depending on the flags used. Instead, expose the needed functionality through methods that do not expose shaded classes. Author: Marcelo Vanzin Closes #7601 from vanzin/SPARK-9261 and squashes the following commits: 4f64a16 [Marcelo Vanzin] [SPARK-9261] [streaming] Avoid calling APIs that expose shaded classes. --- .../scala/org/apache/spark/ui/WebUI.scala | 19 +++++++++++++++++++ .../spark/streaming/ui/StreamingTab.scala | 12 +++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 2c84e4485996e..61449847add3d 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -107,6 +107,25 @@ private[spark] abstract class WebUI( } } + /** + * Add a handler for static content. + * + * @param resourceBase Root of where to find resources to serve. + * @param path Path in UI where to mount the resources. + */ + def addStaticHandler(resourceBase: String, path: String): Unit = { + attachHandler(JettyUtils.createStaticHandler(resourceBase, path)) + } + + /** + * Remove a static content handler. + * + * @param path Path in UI to unmount. + */ + def removeStaticHandler(path: String): Unit = { + handlers.find(_.getContextPath() == path).foreach(detachHandler) + } + /** Initialize all components of the server. */ def initialize() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index e0c0f57212f55..bc53f2a31f6d1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -17,11 +17,9 @@ package org.apache.spark.streaming.ui -import org.eclipse.jetty.servlet.ServletContextHandler - import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.ui.{JettyUtils, SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab} import StreamingTab._ @@ -42,18 +40,14 @@ private[spark] class StreamingTab(val ssc: StreamingContext) attachPage(new StreamingPage(this)) attachPage(new BatchPage(this)) - var staticHandler: ServletContextHandler = null - def attach() { getSparkUI(ssc).attachTab(this) - staticHandler = JettyUtils.createStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") - getSparkUI(ssc).attachHandler(staticHandler) + getSparkUI(ssc).addStaticHandler(STATIC_RESOURCE_DIR, "/static/streaming") } def detach() { getSparkUI(ssc).detachTab(this) - getSparkUI(ssc).detachHandler(staticHandler) - staticHandler = null + getSparkUI(ssc).removeStaticHandler("/static/streaming") } } From 9a11396113d4bb0e76e0520df4fc58e7a8ec9f69 Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Fri, 24 Jul 2015 11:56:55 -0700 Subject: [PATCH 053/219] [SPARK-9270] [PYSPARK] allow --name option in pyspark This is continuation of #7512 which added `--name` option to spark-shell. This PR adds the same option to pyspark. 
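(Editor's aside, not part of the commit: the reason the hard-coded name used to win is that an app name applied at context construction is set last, overwriting whatever `spark.app.name` the user supplied. A minimal Scala sketch of that ordering with SparkConf, whose setAppName simply writes the same key:)

import org.apache.spark.SparkConf

val conf = new SparkConf(loadDefaults = false)
  .set("spark.app.name", "UserChosenName") // what --conf would provide
conf.setAppName("PySparkShell")            // explicit name applied later wins
assert(conf.get("spark.app.name") == "PySparkShell")

Moving the name onto the spark-submit invocation, as the diff below does, sets it before the shell creates the context, so a `--name` given later on the user's command line takes precedence.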
Note that `--conf spark.app.name` on the command line has no effect in spark-shell and pyspark; `--name` must be used instead. This is in fact inconsistent with spark-sql, which doesn't accept the `--name` option but does accept `--conf spark.app.name`. I am not fixing this inconsistency in this PR. IMO, only one of `--name` and `--conf spark.app.name` is needed, not both; but since I cannot decide which to choose, I am not making any change here. Author: Cheolsoo Park Closes #7610 from piaozhexiu/SPARK-9270 and squashes the following commits: 763e86d [Cheolsoo Park] Update windows script 400b7f9 [Cheolsoo Park] Allow --name option to pyspark --- bin/pyspark | 2 +- bin/pyspark2.cmd | 2 +- python/pyspark/shell.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/pyspark b/bin/pyspark index f9dbddfa53560..8f2a3b5a7717b 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -82,4 +82,4 @@ fi export PYSPARK_DRIVER_PYTHON export PYSPARK_DRIVER_PYTHON_OPTS -exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main "$@" +exec "$SPARK_HOME"/bin/spark-submit pyspark-shell-main --name "PySparkShell" "$@" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index 45e9e3def5121..3c6169983e76b 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -35,4 +35,4 @@ set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.8.2.1-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py -call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main %* +call %SPARK_HOME%\bin\spark-submit2.cmd pyspark-shell-main --name "PySparkShell" %* diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 144cdf0b0cdd5..99331297c19f0 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -40,7 +40,7 @@ if os.environ.get("SPARK_EXECUTOR_URI"): SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) -sc = SparkContext(appName="PySparkShell", pyFiles=add_files) +sc = SparkContext(pyFiles=add_files) atexit.register(lambda: sc.stop()) try: From 64135cbb3363e3b74dad3c0498cb9959c047d381 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Jul 2015 12:36:44 -0700 Subject: [PATCH 054/219] [SPARK-9067] [SQL] Close reader in NewHadoopRDD early if there is no more data JIRA: https://issues.apache.org/jira/browse/SPARK-9067 According to the description of the JIRA ticket, calling `reader.close()` only after the task has finished causes memory and file-open-limit problems, since these resources stay occupied even after we no longer need them. This PR simply closes the reader early, as soon as we know there is no more data to read. Author: Liang-Chi Hsieh Closes #7424 from viirya/close_reader and squashes the following commits: 3ff64e5 [Liang-Chi Hsieh] For comments. 3d20267 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into close_reader e152182 [Liang-Chi Hsieh] For comments. 5116cbe [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into close_reader 3ceb755 [Liang-Chi Hsieh] For comments. e34d98e [Liang-Chi Hsieh] For comments. 50ed729 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into close_reader 216912f [Liang-Chi Hsieh] Fix it. f429016 [Liang-Chi Hsieh] Release reader if we don't need it. a305621 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into close_reader 67569da [Liang-Chi Hsieh] Close reader early if there is no more data.
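For readers skimming the diff below, the core pattern is small enough to show in isolation. The following is a minimal, self-contained sketch of the idea (the iterator class and the file-reading details are illustrative only, not Spark code): hold the resource in a nullable var, make close() idempotent, and call it the moment hasNext learns there is no more data, instead of waiting for the task-completion callback.

    import java.io.{BufferedReader, FileReader}

    // Illustrative only: an iterator that owns a resource and releases it early.
    class EarlyClosingLineIterator(path: String) extends Iterator[String] {
      private var reader: BufferedReader = new BufferedReader(new FileReader(path))
      private var nextLine: String = null

      override def hasNext: Boolean = {
        if (nextLine == null && reader != null) {
          nextLine = reader.readLine()
          if (nextLine == null) close() // end of data: free the handle right away
        }
        nextLine != null
      }

      override def next(): String = {
        if (!hasNext) throw new NoSuchElementException("end of file")
        val line = nextLine
        nextLine = null
        line
      }

      def close(): Unit = {
        // Idempotent: safe to call again later, e.g. from a completion callback.
        if (reader != null) {
          reader.close()
          reader = null
        }
      }
    }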
--- .../org/apache/spark/rdd/NewHadoopRDD.scala | 37 ++++++++++++------- .../spark/sql/execution/SqlNewHadoopRDD.scala | 36 +++++++++++------- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index f827270ee6a44..f83a051f5da11 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -128,7 +128,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -141,6 +141,12 @@ class NewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. + close() + } havePair = !finished } !finished @@ -159,18 +165,23 @@ class NewHadoopRDD[K, V]( private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + // Close reader and release it + reader.close() + reader = null + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala index e1c1a6c06268f..3d75b6a91def6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala @@ -147,7 +147,7 @@ private[sql] class SqlNewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val reader = format.createRecordReader( + private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -160,6 +160,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( override def hasNext: Boolean = { if (!finished && !havePair) { finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. 
+ close() + } havePair = !finished } !finished @@ -178,18 +184,22 @@ private def close() { try { - reader.close() - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) + if (reader != null) { + reader.close() + reader = null + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } } } } catch { From a400ab516fa93185aa683a596f9d7c6c1a02f330 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 24 Jul 2015 14:58:07 -0700 Subject: [PATCH 055/219] [SPARK-7045] [MLLIB] Avoid intermediate representation when creating model Word2Vec used to convert from an Array[Float] representation to a Map[String, Array[Float]] and then back to an Array[Float] through Word2VecModel. This patch avoids that round-trip conversion while still supporting the older method of supplying a Map. Author: MechCoder Closes #5748 from MechCoder/spark-7045 and squashes the following commits: e308913 [MechCoder] move docs 5703116 [MechCoder] minor fa04313 [MechCoder] style fixes b1d61c4 [MechCoder] better errors and tests 3b32c8c [MechCoder] [SPARK-7045] Avoid intermediate representation when creating model --- .../apache/spark/mllib/feature/Word2Vec.scala | 85 +++++++++++-------- .../spark/mllib/feature/Word2VecSuite.scala | 6 ++ 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f087d06d2a46a..cbbd2b0c8d060 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging { } newSentences.unpersist() - val word2VecMap = mutable.HashMap.empty[String, Array[Float]] - var i = 0 - while (i < vocabSize) { - val word = bcVocab.value(i).word - val vector = new Array[Float](vectorSize) - Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) - word2VecMap += word -> vector - i += 1 - } - - new Word2VecModel(word2VecMap.toMap) + val wordArray = vocab.map(_.word) + new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) } /** @@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging { /** * :: Experimental :: * Word2Vec model + * @param wordIndex maps each word to an index, which can retrieve the corresponding + * vector from wordVectors + * @param wordVectors array of length numWords * vectorSize, vector corresponding + * to the word mapped with index i can be retrieved by the slice + * (i * vectorSize, i * vectorSize + vectorSize) */
@Experimental -class Word2VecModel private[spark] ( - model: Map[String, Array[Float]]) extends Serializable with Saveable { - - // wordList: Ordered list of words obtained from model. - private val wordList: Array[String] = model.keys.toArray - - // wordIndex: Maps each word to an index, which can retrieve the corresponding - // vector from wordVectors (see below). - private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap +class Word2VecModel private[mllib] ( + private val wordIndex: Map[String, Int], + private val wordVectors: Array[Float]) extends Serializable with Saveable { - // vectorSize: Dimension of each word's vector. - private val vectorSize = model.head._2.size private val numWords = wordIndex.size + // vectorSize: Dimension of each word's vector. + private val vectorSize = wordVectors.length / numWords + + // wordList: Ordered list of words obtained from wordIndex. + private val wordList: Array[String] = { + val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip + wl.toArray + } - // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word - // mapped with index i can be retrieved by the slice - // (ind * vectorSize, ind * vectorSize + vectorSize) // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. - private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { - val wordVectors = new Array[Float](vectorSize * numWords) + private val wordVecNorms: Array[Double] = { val wordVecNorms = new Array[Double](numWords) var i = 0 while (i < numWords) { - val vec = model.get(wordList(i)).get - Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize) + val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) i += 1 } - (wordVectors, wordVecNorms) + wordVecNorms + } + + def this(model: Map[String, Array[Float]]) = { + this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { @@ -484,8 +479,9 @@ class Word2VecModel private[spark] ( * @return vector representation of word */ def transform(word: String): Vector = { - model.get(word) match { - case Some(vec) => + wordIndex.get(word) match { + case Some(ind) => + val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize) Vectors.dense(vec.map(_.toDouble)) case None => throw new IllegalStateException(s"$word not in vocabulary") @@ -511,7 +507,7 @@ class Word2VecModel private[spark] ( */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - + // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 @@ -521,13 +517,13 @@ class Word2VecModel private[spark] ( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) // Need not divide with the norm of the given vector since it is constant. 
- val updatedCosines = new Array[Double](numWords) + val cosVec = cosineVec.map(_.toDouble) var ind = 0 while (ind < numWords) { - updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind) + cosVec(ind) /= wordVecNorms(ind) ind += 1 } - wordList.zip(updatedCosines) + wordList.zip(cosVec) .toSeq .sortBy(- _._2) .take(num + 1) @@ -548,6 +544,23 @@ class Word2VecModel private[spark] ( @Experimental object Word2VecModel extends Loader[Word2VecModel] { + private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { + model.keys.zipWithIndex.toMap + } + + private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = { + require(model.nonEmpty, "Word2VecMap should be non-empty") + val (vectorSize, numWords) = (model.head._2.size, model.size) + val wordList = model.keys.toArray + val wordVectors = new Array[Float](vectorSize * numWords) + var i = 0 + while (i < numWords) { + Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize) + i += 1 + } + wordVectors + } + private object SaveLoadV1_0 { val formatVersionV1_0 = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index b6818369208d7..4cc8d1129b858 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -37,6 +37,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") + + // Test that a model built using Word2Vec (i.e. from wordVectors and wordIndex) + // and one built from a Word2VecMap give the same values. + val word2VecMap = model.getVectors + val newModel = new Word2VecModel(word2VecMap) + assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq)) } test("Word2VecModel") { From f99cb5615cbc0b469d52af6bd08f8bf888af58f3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Jul 2015 19:29:01 -0700 Subject: [PATCH 056/219] [SPARK-9330][SQL] Create specialized getStruct getter in InternalRow. Also took the chance to rearrange some of the methods in UnsafeRow to group static/private/public things together. Author: Reynold Xin Closes #7654 from rxin/getStruct and squashes the following commits: b491a09 [Reynold Xin] Fixed typo. 48d77e5 [Reynold Xin] [SPARK-9330][SQL] Create specialized getStruct getter in InternalRow.
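To make the new API concrete, here is a hedged usage sketch. Only `getStruct(ordinal, numFields)` and the primitive getters come from this patch; the row layout and field numbers below are invented for illustration.

    import org.apache.spark.sql.catalyst.InternalRow

    // Illustrative only: read a field of a nested struct without boxing.
    def cityOfFirstAddress(row: InternalRow): String = {
      // Before: row.get(1).asInstanceOf[InternalRow]. After: the caller passes
      // the struct's field count (2 in this made-up schema), which lets
      // UnsafeRow return a row view into its own backing bytes.
      val address: InternalRow = row.getStruct(1, 2)
      address.getUTF8String(1).toString
    }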
--- .../sql/catalyst/expressions/UnsafeRow.java | 87 ++++++++++++------- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../spark/sql/catalyst/InternalRow.scala | 22 +++-- .../catalyst/expressions/BoundAttribute.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 5 +- 5 files changed, 77 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a8986608855e2..225f6e6553d19 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -51,28 +51,9 @@ */ public final class UnsafeRow extends MutableRow { - private Object baseObject; - private long baseOffset; - - public Object getBaseObject() { return baseObject; } - public long getBaseOffset() { return baseOffset; } - public int getSizeInBytes() { return sizeInBytes; } - - /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ - private int numFields; - - /** The size of this row's backing data, in bytes) */ - private int sizeInBytes; - - @Override - public int numFields() { return numFields; } - - /** The width of the null tracking bit set, in bytes */ - private int bitSetWidthInBytes; - - private long getFieldOffset(int ordinal) { - return baseOffset + bitSetWidthInBytes + ordinal * 8L; - } + ////////////////////////////////////////////////////////////////////////////// + // Static methods + ////////////////////////////////////////////////////////////////////////////// public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; @@ -103,7 +84,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { DoubleType, DateType, TimestampType - }))); + }))); // We support get() on a superset of the types for which we support set(): final Set _readableFieldTypes = new HashSet<>( @@ -115,12 +96,48 @@ public static int calculateBitSetWidthInBytes(int numFields) { readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); } + ////////////////////////////////////////////////////////////////////////////// + // Private fields and methods + ////////////////////////////////////////////////////////////////////////////// + + private Object baseObject; + private long baseOffset; + + /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ + private int numFields; + + /** The size of this row's backing data, in bytes */ + private int sizeInBytes; + + private void setNotNullAt(int i) { + assertIndexIsValid(i); + BitSetMethods.unset(baseObject, baseOffset, i); + } + + /** The width of the null tracking bit set, in bytes */ + private int bitSetWidthInBytes; + + private long getFieldOffset(int ordinal) { + return baseOffset + bitSetWidthInBytes + ordinal * 8L; + } + + ////////////////////////////////////////////////////////////////////////////// + // Public methods + ////////////////////////////////////////////////////////////////////////////// + /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, * since the value returned by this constructor is equivalent to a null pointer.
*/ public UnsafeRow() { } + public Object getBaseObject() { return baseObject; } + public long getBaseOffset() { return baseOffset; } + public int getSizeInBytes() { return sizeInBytes; } + + @Override + public int numFields() { return numFields; } + /** * Update this UnsafeRow to point to different backing data. * @@ -130,7 +147,7 @@ public UnsafeRow() { } * @param sizeInBytes the size of this row's backing data, in bytes */ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { - assert numFields >= 0 : "numFields should >= 0"; + assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; @@ -153,11 +170,6 @@ public void setNullAt(int i) { PlatformDependent.UNSAFE.putLong(baseObject, getFieldOffset(i), 0); } - private void setNotNullAt(int i) { - assertIndexIsValid(i); - BitSetMethods.unset(baseObject, baseOffset, i); - } - @Override public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); @@ -316,6 +328,21 @@ public String getString(int i) { return getUTF8String(i).toString(); } + @Override + public UnsafeRow getStruct(int i, int numFields) { + if (isNullAt(i)) { + return null; + } else { + assertIndexIsValid(i); + final long offsetAndSize = getLong(i); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + final UnsafeRow row = new UnsafeRow(); + row.pointTo(baseObject, baseOffset + offset, numFields, size); + return row; + } + } + /** * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal * byte array rather than referencing data stored in a data page. 
@@ -388,7 +415,7 @@ public boolean equals(Object other) { */ public byte[] getBytes() { if (baseObject instanceof byte[] && baseOffset == PlatformDependent.BYTE_ARRAY_OFFSET - && (((byte[]) baseObject).length == sizeInBytes)) { + && (((byte[]) baseObject).length == sizeInBytes)) { return (byte[]) baseObject; } else { byte[] bytes = new byte[sizeInBytes]; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5c3072a77aeba..7416ddbaef3fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -271,7 +271,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Row = - toScala(row.get(column).asInstanceOf[InternalRow]) + toScala(row.getStruct(column, structType.size)) } private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index efc4faea569b2..f248b1f338acc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -52,6 +52,21 @@ abstract class InternalRow extends Serializable { def getDouble(i: Int): Double = getAs[Double](i) + def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + + def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + + // This is only used for tests + def getString(i: Int): String = getAs[UTF8String](i).toString + + /** + * Returns a struct from ordinal position. + * + * @param ordinal position to get the struct from. + * @param numFields number of fields the struct type has + */ + def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal) + override def toString: String = s"[${this.mkString(",")}]" /** @@ -145,13 +160,6 @@ abstract class InternalRow extends Serializable { */ def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) - def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) - - def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) - - // This is only use for test - def getString(i: Int): String = getAs[UTF8String](i).toString - // Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = { var result: Int = 37 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6aa4930cb8587..1f7adcd36ec14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case DoubleType => input.getDouble(ordinal) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) + case t: StructType => input.getStruct(ordinal, t.size) case _ => input.get(ordinal) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 48225e1574600..4a90f1b559896 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -109,6 +109,7 @@ class CodeGenContext { case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" + case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case _ => s"($jt)$row.apply($ordinal)" } } @@ -249,13 +250,13 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected val mutableRowType: String = classOf[MutableRow].getName protected val genericMutableRowType: String = classOf[GenericMutableRow].getName - protected def declareMutableStates(ctx: CodeGenContext) = { + protected def declareMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map { case (javaType, variableName, _) => s"private $javaType $variableName;" }.mkString("\n ") } - protected def initMutableStates(ctx: CodeGenContext) = { + protected def initMutableStates(ctx: CodeGenContext): String = { ctx.mutableStates.map(_._3).mkString("\n ") } From c84acd4aa4f8bee98baa550cd6801c797a8a7a25 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 24 Jul 2015 19:35:24 -0700 Subject: [PATCH 057/219] [SPARK-9331][SQL] Add a code formatter to auto-format generated code. The generated expression code can be hard to read since it is not indented well. This patch adds a code formatter that indents the code automatically when we output it to the screen. Author: Reynold Xin Closes #7656 from rxin/codeformatter and squashes the following commits: 5ba0e90 [Reynold Xin] [SPARK-9331][SQL] Add a code formatter to auto-format generated code.
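As a quick illustration of the intended behavior, the snippet below mirrors the "nested example" test case added in the new suite; the input string is made up, and the formatter simply re-indents by tracking curly-brace depth:

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeFormatter

    val generated = "class A {\nif (c) {\nduh;\n}\n}"
    println(CodeFormatter.format(generated))
    // Prints:
    // class A {
    //   if (c) {
    //     duh;
    //   }
    // }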
--- .../expressions/codegen/CodeFormatter.scala | 60 +++++++++++++++ .../expressions/codegen/CodeGenerator.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 11 +-- .../codegen/GenerateOrdering.scala | 2 +- .../codegen/GeneratePredicate.scala | 2 +- .../codegen/GenerateProjection.scala | 3 +- .../codegen/GenerateUnsafeProjection.scala | 2 +- .../codegen/CodeFormatterSuite.scala | 76 +++++++++++++++++++ 8 files changed, 148 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala new file mode 100644 index 0000000000000..2087cc7f109bc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +/** + * A utility class that indents a block of code based on the curly braces. + * + * This is used to prettify generated code when in debug mode (or exceptions). + * + * Written by Matei Zaharia.
+ */ +object CodeFormatter { + def format(code: String): String = new CodeFormatter().addLines(code).result() +} + +private class CodeFormatter { + private val code = new StringBuilder + private var indentLevel = 0 + private val indentSize = 2 + private var indentString = "" + + private def addLine(line: String): Unit = { + val indentChange = line.count(_ == '{') - line.count(_ == '}') + val newIndentLevel = math.max(0, indentLevel + indentChange) + // Lines starting with '}' should be de-indented even if they contain '{' after; + // in addition, lines ending with ':' are typically labels + val thisLineIndent = if (line.startsWith("}") || line.endsWith(":")) { + " " * (indentSize * (indentLevel - 1)) + } else { + indentString + } + code.append(thisLineIndent) + code.append(line) + code.append("\n") + indentLevel = newIndentLevel + indentString = " " * (indentSize * newIndentLevel) + } + + private def addLines(code: String): CodeFormatter = { + code.split('\n').foreach(s => addLine(s.trim())) + this + } + + private def result(): String = code.result() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4a90f1b559896..508882acbee5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -299,7 +299,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - val msg = s"failed to compile:\n $code" + val msg = "failed to compile:\n " + CodeFormatter.format(code) logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d838268f46956..825031a4faf5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.codegen +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import scala.collection.mutable.ArrayBuffer - // MutableProjection is not accessible in Java abstract class BaseMutableProjection extends MutableProjection @@ -45,10 +45,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val evaluationCode = e.gen(ctx) evaluationCode.code + s""" - if(${evaluationCode.isNull}) + if (${evaluationCode.isNull}) { mutableRow.setNullAt($i); - else + } else { ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; + } """ } // collect projections into blocks as function has 64kb codesize limit in JVM @@ -119,7 +120,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) () => { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 2e6f9e204d813..dbd4616d281c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -107,7 +107,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR } }""" - logDebug(s"Generated Ordering: $code") + logDebug(s"Generated Ordering: ${CodeFormatter.format(code)}") compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 1dda5992c3654..dfd593fb7c064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -60,7 +60,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool } }""" - logDebug(s"Generated predicate '$predicate':\n$code") + logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] (r: InternalRow) => p.eval(r) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index f0efc4bff12ba..a361b216eb472 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -230,7 +230,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } """ - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n${code}") + logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + + CodeFormatter.format(code)) compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index d65e5c38ebf5c..0320bcb827bf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -114,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala new file mode 100644 index 0000000000000..478702fea6146 --- 
/dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite + + +class CodeFormatterSuite extends SparkFunSuite { + + def testCase(name: String)(input: String)(expected: String): Unit = { + test(name) { + assert(CodeFormatter.format(input).trim === expected.trim) + } + } + + testCase("basic example") { + """ + |class A { + |blahblah; + |} + """.stripMargin + }{ + """ + |class A { + | blahblah; + |} + """.stripMargin + } + + testCase("nested example") { + """ + |class A { + | if (c) { + |duh; + |} + |} + """.stripMargin + } { + """ + |class A { + | if (c) { + | duh; + | } + |} + """.stripMargin + } + + testCase("single line") { + """ + |class A { + | if (c) {duh;} + |} + """.stripMargin + }{ + """ + |class A { + | if (c) {duh;} + |} + """.stripMargin + } +} From 19bcd6ab12bf355bc5d774905ec7fe3b5fc8e0e2 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Fri, 24 Jul 2015 22:57:01 -0700 Subject: [PATCH 058/219] [HOTFIX] - Disable Kinesis tests due to rate limits --- .../apache/spark/streaming/kinesis/KinesisStreamSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index f9c952b9468bb..4992b041765e9 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -59,7 +59,7 @@ class KinesisStreamSuite extends KinesisFunSuite } } - test("KinesisUtils API") { + ignore("KinesisUtils API") { ssc = new StreamingContext(sc, Seconds(1)) // Tests the API, does not actually test data receiving val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", @@ -83,7 +83,7 @@ class KinesisStreamSuite extends KinesisFunSuite * you must have AWS credentials available through the default AWS provider chain, * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . 
*/ - testIfEnabled("basic operation") { + ignore("basic operation") { val kinesisTestUtils = new KinesisTestUtils() try { kinesisTestUtils.createStream() From 723db13e0688bf20e2a5f02ad170397c3a287712 Mon Sep 17 00:00:00 2001 From: JD Date: Sat, 25 Jul 2015 00:34:59 -0700 Subject: [PATCH 059/219] [Spark-8668][SQL] Adding expr to functions Author: JD Author: Joseph Batchik Closes #7606 from JDrit/expr and squashes the following commits: ad7f607 [Joseph Batchik] fixing python linter error 9d6daea [Joseph Batchik] removed order by per @rxin's comment 707d5c6 [Joseph Batchik] Added expr to fuctions.py 79df83c [JD] added example to the docs b89eec8 [JD] moved function up as per @rxin's comment 4960909 [JD] updated per @JoshRosen's comment 2cb329c [JD] updated per @rxin's comment 9a9ad0c [JD] removing unused import 6dc26d0 [JD] removed split 7f2222c [JD] Adding expr function as per SPARK-8668 --- python/pyspark/sql/functions.py | 10 ++++++++++ python/pyspark/sql/tests.py | 7 +++++++ .../scala/org/apache/spark/sql/functions.scala | 15 +++++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 +++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 719e623a1a11f..d930f7db25d25 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -541,6 +541,16 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +def expr(str): + """Parses the expression string into the column that it represents + + >>> df.select(expr("length(name)")).collect() + [Row('length(name)=5), Row('length(name)=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.expr(str)) + + @ignore_unicode_prefix @since(1.5) def length(col): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ea821f486f13a..5aa6135dc1ee7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -846,6 +846,13 @@ def test_bitwise_operations(self): result = df.select(functions.bitwiseNOT(df.b)).collect()[0].asDict() self.assertEqual(~75, result['~b']) + def test_expr(self): + from pyspark.sql import functions + row = Row(a="length string", b=75) + df = self.sqlCtx.createDataFrame([row]) + result = df.select(functions.expr("length(a)")).collect()[0].asDict() + self.assertEqual(13, result["'length(a)"]) + def test_replace(self): schema = StructType([ StructField("name", StringType(), True), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bfeecbe8b2ab5..cab3db609dd4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -792,6 +792,18 @@ object functions { */ def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + /** + * Parses the expression string into the column that it represents, similar to + * DataFrame.selectExpr + * {{{ + * // get the number of words of each length + * 
df.groupBy(expr("length(word)")).count() * }}} * * @group normal_funcs */ def expr(expr: String): Column = Column(new SqlParser().parseExpression(expr)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2451,5 +2463,4 @@ object functions { } UnresolvedFunction(udfName, exprs, isDistinct = false) } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 95a1106cf072d..cd386b7a3ecf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -112,6 +112,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil) } + test("SPARK-8668 expr function") { + checkAnswer(Seq((1, "Bobby G.")) + .toDF("id", "name") + .select(expr("length(name)"), expr("abs(id)")), Row(8, 1)) + + checkAnswer(Seq((1, "building burrito tunnels"), (1, "major projects")) + .toDF("id", "saying") + .groupBy(expr("length(saying)")) + .count(), Row(24, 1) :: Row(14, 1) :: Nil) + } + test("SQL Dialect Switching to a new SQL parser") { val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) From f0ebab3f6d3a9231474acf20110db72c0fb51882 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 25 Jul 2015 01:28:46 -0700 Subject: [PATCH 060/219] [SPARK-9336][SQL] Remove extra JoinedRows They were added to improve performance (so the JIT can inline the JoinedRow calls). However, we can get the same benefit by projecting the output to UnsafeRow in the Tungsten variants of the operators.
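For context, the whole family of classes deleted below implements one small idea, sketched here in a generic, illustrative form (not the Spark class itself): view two rows as one by offsetting field indexes, so a join can emit a combined row without copying either side.

    // Illustrative only: the index-offset trick behind JoinedRow.
    class JoinedSeq[A](left: IndexedSeq[A], right: IndexedSeq[A]) extends IndexedSeq[A] {
      override def length: Int = left.length + right.length
      override def apply(i: Int): A =
        if (i < left.length) left(i) else right(i - left.length)
    }

    // e.g. new JoinedSeq(Vector(1, 2), Vector(3)).toList == List(1, 2, 3)

The numbered copies (JoinedRow2 through JoinedRow6) existed only to keep each call site monomorphic for the JIT; once the Tungsten operators project their output to UnsafeRow, that hack is no longer needed.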
Author: Reynold Xin Closes #7659 from rxin/remove-joinedrows and squashes the following commits: 7510447 [Reynold Xin] [SPARK-9336][SQL] Remove extra JoinedRows --- .../sql/catalyst/expressions/Projection.scala | 494 +----------------- .../spark/sql/execution/Aggregate.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../apache/spark/sql/execution/Window.scala | 2 +- .../aggregate/sortBasedIterators.scala | 2 +- .../spark/sql/execution/joins/HashJoin.scala | 2 +- .../sql/execution/joins/SortMergeJoin.scala | 2 +- 7 files changed, 8 insertions(+), 498 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index dbda05a792cbf..6023a2c564389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -44,7 +44,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { new GenericInternalRow(outputArray) } - override def toString: String = s"Row => [${exprArray.mkString(",")}]" + override def toString(): String = s"Row => [${exprArray.mkString(",")}]" } /** @@ -58,7 +58,7 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu this(expressions.map(BindReferences.bindReference(_, inputSchema))) private[this] val exprArray = expressions.toArray - private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size) + private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) def currentValue: InternalRow = mutableRow override def target(row: MutableRow): MutableProjection = { @@ -186,496 +186,6 @@ class JoinedRow extends InternalRow { if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } - override def get(i: Int): Any = - if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def copy(): InternalRow = { - val totalSize = row1.numFields + row2.numFields - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. 
- if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - * The `JoinedRow` class is used in many performance critical situation. Unfortunately, since there - * are multiple different types of `Rows` that could be stored as `row1` and `row2` most of the - * calls in the critical path are polymorphic. By creating special versions of this class that are - * used in only a single location of the code, we increase the chance that only a single type of - * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds - * crazy but in benchmarks it had noticeable effects. - */ -class JoinedRow2 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def numFields: Int = row1.numFields + row2.numFields - - override def getUTF8String(i: Int): UTF8String = { - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - } - - override def getBinary(i: Int): Array[Byte] = { - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - } - - override def get(i: Int): Any = - if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def copy(): InternalRow = { - val totalSize = row1.numFields + row2.numFields - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. 
- if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow3 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def numFields: Int = row1.numFields + row2.numFields - - override def getUTF8String(i: Int): UTF8String = { - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - } - - override def getBinary(i: Int): Array[Byte] = { - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - } - - - override def get(i: Int): Any = - if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def copy(): InternalRow = { - val totalSize = row1.numFields + row2.numFields - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow4 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. 
*/ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. */ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def numFields: Int = row1.numFields + row2.numFields - - override def getUTF8String(i: Int): UTF8String = { - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - } - - override def getBinary(i: Int): Array[Byte] = { - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - } - - - override def get(i: Int): Any = - if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def copy(): InternalRow = { - val totalSize = row1.numFields + row2.numFields - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow5 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def numFields: Int = row1.numFields + row2.numFields - - override def getUTF8String(i: Int): UTF8String = { - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - } - - override def getBinary(i: Int): Array[Byte] = { - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - } - - - override def get(i: Int): Any = - if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) - - override def isNullAt(i: Int): Boolean = - if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) - - override def getInt(i: Int): Int = - if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) - - override def getLong(i: Int): Long = - if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) - - override def getDouble(i: Int): Double = - if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) - - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) - - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) - - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) - - override def copy(): InternalRow = { - val totalSize = row1.numFields + row2.numFields - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) - } - - override def toString: String = { - // Make sure toString never throws NullPointerException. - if ((row1 eq null) && (row2 eq null)) { - "[ empty row ]" - } else if (row1 eq null) { - row2.mkString("[", ",", "]") - } else if (row2 eq null) { - row1.mkString("[", ",", "]") - } else { - mkString("[", ",", "]") - } - } -} - -/** - * JIT HACK: Replace with macros - */ -class JoinedRow6 extends InternalRow { - private[this] var row1: InternalRow = _ - private[this] var row2: InternalRow = _ - - def this(left: InternalRow, right: InternalRow) = { - this() - row1 = left - row2 = right - } - - /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ - def apply(r1: InternalRow, r2: InternalRow): InternalRow = { - row1 = r1 - row2 = r2 - this - } - - /** Updates this JoinedRow by updating its left base row. Returns itself. */ - def withLeft(newLeft: InternalRow): InternalRow = { - row1 = newLeft - this - } - - /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ - def withRight(newRight: InternalRow): InternalRow = { - row2 = newRight - this - } - - override def toSeq: Seq[Any] = row1.toSeq ++ row2.toSeq - - override def numFields: Int = row1.numFields + row2.numFields - - override def getUTF8String(i: Int): UTF8String = { - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - } - - override def getBinary(i: Int): Array[Byte] = { - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - } - - override def get(i: Int): Any = if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index c2c945321db95..e8c6a0f8f801d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -172,7 +172,7 @@ case class Aggregate( private[this] val resultProjection = new InterpretedMutableProjection( resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow4 + private[this] val joinedRow = new JoinedRow override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 5ed158b3d2912..5ad4691a5ca07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -269,7 +269,7 @@ case class GeneratedAggregate( namedGroups.map(_._2) ++ computationSchema) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - val joinedRow = new JoinedRow3 + val joinedRow = new JoinedRow if (!iter.hasNext) { // This is an empty input, so return early so that we do not allocate data structures diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index de04132eb1104..91c8a02e2b5bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -298,7 +298,7 @@ case class Window( var rowsSize = 0 override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable - val join = new JoinedRow6 + val join = new JoinedRow val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index b8e95a5a2a4da..1b89edafa8dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -106,7 +106,7 @@ private[sql] abstract class SortAggregationIterator( new GenericMutableRow(size) } - protected val joinedRow = new JoinedRow4 + protected val joinedRow = new JoinedRow protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index ae34409bcfcca..46ab5b0d1cc6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -69,7 +69,7 @@ trait HashJoin { private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. - private[this] val joinRow = new JoinedRow2 + private[this] val joinRow = new JoinedRow private[this] val resultProjection: Projection = { if (supportUnsafe) { UnsafeProjection.create(self.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 981447eacad74..bb18b5403f8e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -66,7 +66,7 @@ case class SortMergeJoin( leftResults.zipPartitions(rightResults) { (leftIter, rightIter) => new Iterator[InternalRow] { // Mutable per row objects. - private[this] val joinRow = new JoinedRow5 + private[this] val joinRow = new JoinedRow private[this] var leftElement: InternalRow = _ private[this] var rightElement: InternalRow = _ private[this] var leftKey: InternalRow = _ From 215713e19924dff69d226a97f1860a5470464d15 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 25 Jul 2015 01:37:41 -0700 Subject: [PATCH 061/219] [SPARK-9334][SQL] Remove UnsafeRowConverter in favor of UnsafeProjection. The two are redundant. Once this patch is merged, I plan to remove the inbound conversions from unsafe aggregates. Author: Reynold Xin Closes #7658 from rxin/unsafeconverters and squashes the following commits: ed19e6c [Reynold Xin] Updated support types. 2a56d7e [Reynold Xin] [SPARK-9334][SQL] Remove UnsafeRowConverter in favor of UnsafeProjection. 
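For reviewers, here is a minimal sketch of the before/after usage this consolidation implies, distilled from the test changes in this patch. `UnsafeProjection.create` and `apply` are taken from the diffs below; the field types and values are made-up illustration data:

```scala
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

// Before this patch (removed below): new UnsafeRowConverter(fieldTypes), then
// getSizeRequirement(row), then writeRow(...) into a manually managed buffer.
// After this patch, a single projection call handles sizing and buffering:
val fieldTypes: Array[DataType] = Array(LongType, IntegerType)
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 1L)  // made-up values, for illustration only
row.setInt(1, 2)

val projection = UnsafeProjection.create(fieldTypes)
val unsafeRow: UnsafeRow = projection.apply(row)  // sizing and buffer reuse are internal
```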
--- .../UnsafeFixedWidthAggregationMap.java | 55 +--- .../expressions/UnsafeRowWriters.java | 83 ++++++ .../sql/catalyst/expressions/Projection.scala | 4 +- .../expressions/UnsafeRowConverter.scala | 276 ------------------ .../codegen/GenerateUnsafeProjection.scala | 131 ++++++--- .../expressions/ExpressionEvalHelper.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 79 ++--- .../execution/UnsafeRowSerializerSuite.scala | 17 +- .../apache/spark/unsafe/types/ByteArray.java | 38 +++ .../apache/spark/unsafe/types/UTF8String.java | 15 + 10 files changed, 262 insertions(+), 438 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala create mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 2f7e84a7f59e2..684de6e81d67c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -47,7 +47,7 @@ public final class UnsafeFixedWidthAggregationMap { /** * Encodes grouping keys as UnsafeRows. */ - private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + private final UnsafeProjection groupingKeyProjection; /** * A hashmap which maps from opaque bytearray keys to bytearray values. @@ -59,14 +59,6 @@ public final class UnsafeFixedWidthAggregationMap { */ private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); - /** - * Scratch space that is used when encoding grouping keys into UnsafeRow format. - * - * By default, this is a 8 kb array, but it will grow as necessary in case larger keys are - * encountered. - */ - private byte[] groupingKeyConversionScratchSpace = new byte[1024 * 8]; - private final boolean enablePerfMetrics; /** @@ -112,26 +104,17 @@ public UnsafeFixedWidthAggregationMap( TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.emptyAggregationBuffer = - convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); this.aggregationBufferSchema = aggregationBufferSchema; - this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); + this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; - } - /** - * Convert a Java object row into an UnsafeRow, allocating it into a new byte array. 
- */ - private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) { - final UnsafeRowConverter converter = new UnsafeRowConverter(schema); - final int size = converter.getSizeRequirement(javaRow); - final byte[] unsafeRow = new byte[size]; - final int writtenLength = - converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET, size); - assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; - return unsafeRow; + // Initialize the buffer for aggregation value + final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema); + this.emptyAggregationBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes(); + assert(this.emptyAggregationBuffer.length == aggregationBufferSchema.length() * 8 + + UnsafeRow.calculateBitSetWidthInBytes(aggregationBufferSchema.length())); } /** @@ -139,30 +122,20 @@ private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) * return the same object. */ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); - // Make sure that the buffer is large enough to hold the key. If it's not, grow it: - if (groupingKeySize > groupingKeyConversionScratchSpace.length) { - groupingKeyConversionScratchSpace = new byte[groupingKeySize]; - } - final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( - groupingKey, - groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize); - assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; + final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey); // Probe our map using the serialized key final BytesToBytesMap.Location loc = map.lookup( - groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize); + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes()); if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: loc.putNewKey( - groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET, - groupingKeySize, + unsafeGroupingKeyRow.getBaseObject(), + unsafeGroupingKeyRow.getBaseOffset(), + unsafeGroupingKeyRow.getSizeInBytes(), emptyAggregationBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, emptyAggregationBuffer.length diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java new file mode 100644 index 0000000000000..87521d1f23c99 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.types.ByteArray;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A set of helper methods to write data into {@link UnsafeRow}s,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriters {
+
+  /** Writer for UTF8String. */
+  public static class UTF8StringWriter {
+
+    public static int getSize(UTF8String input) {
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes());
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) {
+      final long offset = target.getBaseOffset() + cursor;
+      final int numBytes = input.numBytes();
+
+      // zero-out the padding bytes
+      if ((numBytes & 0x07) > 0) {
+        PlatformDependent.UNSAFE.putLong(
+          target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+      }
+
+      // Write the string to the variable length portion.
+      input.writeToMemory(target.getBaseObject(), offset);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
+  /** Writer for binary (byte array) type. */
+  public static class BinaryWriter {
+
+    public static int getSize(byte[] input) {
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
+    }
+
+    public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) {
+      final long offset = target.getBaseOffset() + cursor;
+      final int numBytes = input.length;
+
+      // zero-out the padding bytes
+      if ((numBytes & 0x07) > 0) {
+        PlatformDependent.UNSAFE.putLong(
+          target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+      }
+
+      // Write the bytes to the variable length portion.
+      ByteArray.writeToMemory(input, target.getBaseObject(), offset);
+
+      // Set the fixed length portion.
+      target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+      return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+    }
+  }
+
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 6023a2c564389..fb873e7e99547 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -90,8 +90,10 @@ object UnsafeProjection {
    * Seq[Expression].
*/ def canSupport(schema: StructType): Boolean = canSupport(schema.fields.map(_.dataType)) - def canSupport(types: Array[DataType]): Boolean = types.forall(UnsafeColumnWriter.canEmbed(_)) def canSupport(exprs: Seq[Expression]): Boolean = canSupport(exprs.map(_.dataType).toArray) + private def canSupport(types: Array[DataType]): Boolean = { + types.forall(GenerateUnsafeProjection.canSupport) + } /** * Returns an UnsafeProjection for given StructType. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala deleted file mode 100644 index c47b16c0f8585..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.util.Try - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent -import org.apache.spark.unsafe.array.ByteArrayMethods -import org.apache.spark.unsafe.types.UTF8String - - -/** - * Converts Rows into UnsafeRow format. This class is NOT thread-safe. - * - * @param fieldTypes the data types of the row's columns. - */ -class UnsafeRowConverter(fieldTypes: Array[DataType]) { - - def this(schema: StructType) { - this(schema.fields.map(_.dataType)) - } - - /** Re-used pointer to the unsafe row being written */ - private[this] val unsafeRow = new UnsafeRow() - - /** Functions for encoding each column */ - private[this] val writers: Array[UnsafeColumnWriter] = { - fieldTypes.map(t => UnsafeColumnWriter.forType(t)) - } - - /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */ - private[this] val fixedLengthSize: Int = - (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) - - /** - * Compute the amount of space, in bytes, required to encode the given row. - */ - def getSizeRequirement(row: InternalRow): Int = { - var fieldNumber = 0 - var variableLengthFieldSize: Int = 0 - while (fieldNumber < writers.length) { - if (!row.isNullAt(fieldNumber)) { - variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber) - } - fieldNumber += 1 - } - fixedLengthSize + variableLengthFieldSize - } - - /** - * Convert the given row into UnsafeRow format. - * - * @param row the row to convert - * @param baseObject the base object of the destination address - * @param baseOffset the base offset of the destination address - * @param rowLengthInBytes the length calculated by `getSizeRequirement(row)` - * @return the number of bytes written. 
This should be equal to `getSizeRequirement(row)`. - */ - def writeRow( - row: InternalRow, - baseObject: Object, - baseOffset: Long, - rowLengthInBytes: Int): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, rowLengthInBytes) - - if (writers.length > 0) { - // zero-out the bitset - var n = writers.length / 64 - while (n >= 0) { - PlatformDependent.UNSAFE.putLong( - unsafeRow.getBaseObject, - unsafeRow.getBaseOffset + n * 8, - 0L) - n -= 1 - } - } - - var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize - while (fieldNumber < writers.length) { - if (row.isNullAt(fieldNumber)) { - unsafeRow.setNullAt(fieldNumber) - } else { - appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) - } - fieldNumber += 1 - } - appendCursor - } - -} - -/** - * Function for writing a column into an UnsafeRow. - */ -private abstract class UnsafeColumnWriter { - /** - * Write a value into an UnsafeRow. - * - * @param source the row being converted - * @param target a pointer to the converted unsafe row - * @param column the column to write - * @param appendCursor the offset from the start of the unsafe row to the end of the row; - * used for calculating where variable-length data should be written - * @return the number of variable-length bytes written - */ - def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int - - /** - * Return the number of bytes that are needed to write this variable-length value. - */ - def getSize(source: InternalRow, column: Int): Int -} - -private object UnsafeColumnWriter { - - def forType(dataType: DataType): UnsafeColumnWriter = { - dataType match { - case NullType => NullUnsafeColumnWriter - case BooleanType => BooleanUnsafeColumnWriter - case ByteType => ByteUnsafeColumnWriter - case ShortType => ShortUnsafeColumnWriter - case IntegerType | DateType => IntUnsafeColumnWriter - case LongType | TimestampType => LongUnsafeColumnWriter - case FloatType => FloatUnsafeColumnWriter - case DoubleType => DoubleUnsafeColumnWriter - case StringType => StringUnsafeColumnWriter - case BinaryType => BinaryUnsafeColumnWriter - case t => - throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") - } - } - - /** - * Returns whether the dataType can be embedded into UnsafeRow (not using ObjectPool). 
- */ - def canEmbed(dataType: DataType): Boolean = Try(forType(dataType)).isSuccess -} - -// ------------------------------------------------------------------------------------------------ - -private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { - // Primitives don't write to the variable-length region: - def getSize(sourceRow: InternalRow, column: Int): Int = 0 -} - -private object NullUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setNullAt(column) - 0 - } -} - -private object BooleanUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setBoolean(column, source.getBoolean(column)) - 0 - } -} - -private object ByteUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setByte(column, source.getByte(column)) - 0 - } -} - -private object ShortUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setShort(column, source.getShort(column)) - 0 - } -} - -private object IntUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setInt(column, source.getInt(column)) - 0 - } -} - -private object LongUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setLong(column, source.getLong(column)) - 0 - } -} - -private object FloatUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setFloat(column, source.getFloat(column)) - 0 - } -} - -private object DoubleUnsafeColumnWriter extends PrimitiveUnsafeColumnWriter { - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - target.setDouble(column, source.getDouble(column)) - 0 - } -} - -private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { - - protected[this] def isString: Boolean - protected[this] def getBytes(source: InternalRow, column: Int): Array[Byte] - - override def getSize(source: InternalRow, column: Int): Int = { - val numBytes = getBytes(source, column).length - ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - } - - override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { - val bytes = getBytes(source, column) - write(target, bytes, column, cursor) - } - - def write(target: UnsafeRow, bytes: Array[Byte], column: Int, cursor: Int): Int = { - val offset = target.getBaseOffset + cursor - val numBytes = bytes.length - if ((numBytes & 0x07) > 0) { - // zero-out the padding bytes - PlatformDependent.UNSAFE.putLong(target.getBaseObject, offset + ((numBytes >> 3) << 3), 0L) - } - PlatformDependent.copyMemory( - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - target.getBaseObject, - offset, - numBytes - ) - target.setLong(column, (cursor.toLong << 32) | numBytes.toLong) - ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) - } -} - -private object StringUnsafeColumnWriter extends BytesUnsafeColumnWriter { - protected[this] def isString: Boolean = true - def 
getBytes(source: InternalRow, column: Int): Array[Byte] = { - source.getAs[UTF8String](column).getBytes - } - // TODO(davies): refactor this - // specialized for codegen - def getSize(value: UTF8String): Int = - ByteArrayMethods.roundNumberOfBytesToNearestWord(value.numBytes()) - def write(target: UnsafeRow, value: UTF8String, column: Int, cursor: Int): Int = { - write(target, value.getBytes, column, cursor) - } -} - -private object BinaryUnsafeColumnWriter extends BytesUnsafeColumnWriter { - protected[this] override def isString: Boolean = false - override def getBytes(source: InternalRow, column: Int): Array[Byte] = { - source.getAs[Array[Byte]](column) - } - // specialized for codegen - def getSize(value: Array[Byte]): Int = - ByteArrayMethods.roundNumberOfBytesToNearestWord(value.length) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 0320bcb827bf7..afd0d9cfa1ddd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{NullType, BinaryType, StringType} - +import org.apache.spark.sql.types._ /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -32,25 +31,43 @@ import org.apache.spark.sql.types.{NullType, BinaryType, StringType} */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) + private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName + private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + /** Returns true iff we support this data type. */ + def canSupport(dataType: DataType): Boolean = dataType match { + case t: AtomicType if !t.isInstanceOf[DecimalType] => true + case NullType => true + case _ => false + } + + /** + * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
+ * @param ctx context for code generation + * @param ev specifies the name of the variable for the output [[UnsafeRow]] object + * @param expressions input expressions + * @return generated code to put the expression output into an [[UnsafeRow]] + */ + def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) + : String = { + + val ret = ev.primitive + ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") + val bufferTerm = ctx.freshName("buffer") + ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];") + val cursorTerm = ctx.freshName("cursor") + val numBytesTerm = ctx.freshName("numBytes") - protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() val exprs = expressions.map(_.gen(ctx)) val allExprs = exprs.map(_.code).mkString("\n") val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val stringWriter = "org.apache.spark.sql.catalyst.expressions.StringUnsafeColumnWriter" - val binaryWriter = "org.apache.spark.sql.catalyst.expressions.BinaryUnsafeColumnWriter" + val additionalSize = expressions.zipWithIndex.map { case (e, i) => e.dataType match { case StringType => - s" + (${exprs(i).isNull} ? 0 : $stringWriter.getSize(${exprs(i).primitive}))" + s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $binaryWriter.getSize(${exprs(i).primitive}))" + s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" case _ => "" } }.mkString("") @@ -58,63 +75,85 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val writers = expressions.zipWithIndex.map { case (e, i) => val update = e.dataType match { case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn("target", dt, i, exprs(i).primitive)}" + s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" case StringType => - s"cursor += $stringWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case BinaryType => - s"cursor += $binaryWriter.write(target, ${exprs(i).primitive}, $i, cursor)" + s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") } s"""if (${exprs(i).isNull}) { - target.setNullAt($i); + $ret.setNullAt($i); } else { $update; }""" }.mkString("\n ") - val code = s""" - private $exprType[] expressions; + s""" + $allExprs + int $numBytesTerm = $fixedSize $additionalSize; + if ($numBytesTerm > $bufferTerm.length) { + $bufferTerm = new byte[$numBytesTerm]; + } - public Object generate($exprType[] expr) { - this.expressions = expr; - return new SpecificProjection(); - } + $ret.pointTo( + $bufferTerm, + org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + ${expressions.size}, + $numBytesTerm); + int $cursorTerm = $fixedSize; - class SpecificProjection extends ${classOf[UnsafeProjection].getName} { - private UnsafeRow target = new UnsafeRow(); - private byte[] buffer = new byte[64]; - ${declareMutableStates(ctx)} + $writers + boolean ${ev.isNull} = false; + """ + } - public SpecificProjection() { - ${initMutableStates(ctx)} - } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = + in.map(ExpressionCanonicalizer.execute) + + protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = + 
in.map(BindReferences.bindReference(_, inputSchema)) + + protected def create(expressions: Seq[Expression]): UnsafeProjection = { + val ctx = newCodeGenContext() + + val isNull = ctx.freshName("retIsNull") + val primitive = ctx.freshName("retValue") + val eval = GeneratedExpressionCode("", isNull, primitive) + eval.code = createCode(ctx, eval, expressions) - // Scala.Function1 need this - public Object apply(Object row) { - return apply((InternalRow) row); + val code = s""" + private $exprType[] expressions; + + public Object generate($exprType[] expr) { + this.expressions = expr; + return new SpecificProjection(); } - public UnsafeRow apply(InternalRow i) { - $allExprs + class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + + ${declareMutableStates(ctx)} + + public SpecificProjection() { + ${initMutableStates(ctx)} + } + + // Scala.Function1 need this + public Object apply(Object row) { + return apply((InternalRow) row); + } - // additionalSize had '+' in the beginning - int numBytes = $fixedSize $additionalSize; - if (numBytes > buffer.length) { - buffer = new byte[numBytes]; + public UnsafeRow apply(InternalRow i) { + ${eval.code} + return ${eval.primitive}; } - target.pointTo(buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, - ${expressions.size}, numBytes); - int cursor = $fixedSize; - $writers - return target; } - } - """ + """ - logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") + logDebug(s"code for ${expressions.mkString(",")}:\n$code") val c = compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 6e17ffcda9dc4..4930219aa63cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -43,7 +43,7 @@ trait ExpressionEvalHelper { checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) - if (UnsafeColumnWriter.canEmbed(expression.dataType)) { + if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } checkEvaluationWithOptimization(expression, catalystValue, inputRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index a5d9806c20463..4606bcb57311d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String @@ -34,22 +33,15 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion 
with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.setLong(1, 1) row.setInt(2, 2) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (3 * 8)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) - assert(numBytesWritten === sizeRequired) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) + val unsafeRow: UnsafeRow = converter.apply(row) + assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) @@ -73,25 +65,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with primitive, string and binary types") { val fieldTypes: Array[DataType] = Array(LongType, StringType, BinaryType) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 3) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow( - row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) - assert(numBytesWritten === sizeRequired) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) + assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") assert(unsafeRow.getBinary(2) === "World".getBytes) @@ -99,7 +84,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { test("basic conversion with primitive, string, date and timestamp types") { val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) @@ -107,17 +92,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01"))) row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) - val sizeRequired: Int = converter.getSizeRequirement(row) - assert(sizeRequired === 8 + (8 * 4) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) - val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = - converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, sizeRequired) - assert(numBytesWritten === sizeRequired) - - val unsafeRow = new UnsafeRow() - unsafeRow.pointTo( - buffer, PlatformDependent.LONG_ARRAY_OFFSET, 
fieldTypes.length, sizeRequired) + assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow @@ -148,26 +126,18 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // DecimalType.Default, // ArrayType(IntegerType) ) - val converter = new UnsafeRowConverter(fieldTypes) + val converter = UnsafeProjection.create(fieldTypes) val rowWithAllNullColumns: InternalRow = { val r = new SpecificMutableRow(fieldTypes) - for (i <- 0 to fieldTypes.length - 1) { + for (i <- fieldTypes.indices) { r.setNullAt(i) } r } - val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) - val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired) - assert(numBytesWritten === sizeRequired) + val createdFromNull: UnsafeRow = converter.apply(rowWithAllNullColumns) - val createdFromNull = new UnsafeRow() - createdFromNull.pointTo( - createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, sizeRequired) for (i <- fieldTypes.indices) { assert(createdFromNull.isNullAt(i)) } @@ -202,15 +172,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { // r.update(11, Array(11)) r } - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) - converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, - sizeRequired) - val setToNullAfterCreation = new UnsafeRow() - setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, - sizeRequired) + val setToNullAfterCreation = converter.apply(rowWithNoNullColumns) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) @@ -228,8 +191,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setNullAt(i) } // There are some garbage left in the var-length area - assert(Arrays.equals(createdFromNullBuffer, - java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes())) setToNullAfterCreation.setNullAt(0) setToNullAfterCreation.setBoolean(1, false) @@ -269,12 +231,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - val converter = new UnsafeRowConverter(fieldTypes) - val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1)) - val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2)) - converter.writeRow(row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length) - converter.writeRow(row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length) - - assert(row1Buffer.toSeq === row2Buffer.toSeq) + val converter = UnsafeProjection.create(fieldTypes) + assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index a1e1695717e23..40b47ae18d648 
100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -22,29 +22,22 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeRowConverter}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.PlatformDependent

 class UnsafeRowSerializerSuite extends SparkFunSuite {

   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
     val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
-    val rowConverter = new UnsafeRowConverter(schema)
-    val rowSizeInBytes = rowConverter.getSizeRequirement(internalRow)
-    val byteArray = new Array[Byte](rowSizeInBytes)
-    rowConverter.writeRow(
-      internalRow, byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, rowSizeInBytes)
-    val unsafeRow = new UnsafeRow()
-    unsafeRow.pointTo(byteArray, PlatformDependent.BYTE_ARRAY_OFFSET, row.length, rowSizeInBytes)
-    unsafeRow
+    val converter = UnsafeProjection.create(schema)
+    converter.apply(internalRow)
   }

-  ignore("toUnsafeRow() test helper method") {
+  test("toUnsafeRow() test helper method") {
     // This currently doesn't work because the generic getter throws an exception.
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))
-    assert(row.getString(0) === unsafeRow.get(0).toString)
+    assert(row.getString(0) === unsafeRow.getUTF8String(0).toString)
     assert(row.getInt(1) === unsafeRow.getInt(1))
   }

diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
new file mode 100644
index 0000000000000..69b0e206cef18
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.types;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class ByteArray {
+
+  /**
+   * Writes the content of a byte array into a memory address, identified by an object and an
+   * offset. The target memory address must already have been allocated, and have enough space to
+   * hold all the bytes in this array.
+   */
+  public static void writeToMemory(byte[] src, Object target, long targetOffset) {
+    PlatformDependent.copyMemory(
+      src,
+      PlatformDependent.BYTE_ARRAY_OFFSET,
+      target,
+      targetOffset,
+      src.length
+    );
+  }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 6d8dcb1cbf876..85381cf0ef425 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -95,6 +95,21 @@ protected UTF8String(Object base, long offset, int size) {
     this.numBytes = size;
   }

+  /**
+   * Writes the content of this string into a memory address, identified by an object and an
+   * offset. The target memory address must already have been allocated, and have enough space to
+   * hold all the bytes in this string.
+   */
+  public void writeToMemory(Object target, long targetOffset) {
+    PlatformDependent.copyMemory(
+      base,
+      offset,
+      target,
+      targetOffset,
+      numBytes
+    );
+  }
+
   /**
    * Returns the number of bytes for a code point with the first byte as `b`
    * @param b The first byte of a code point

From c980e20cf17f2980c564beab9b241022872e29ea Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Sat, 25 Jul 2015 11:05:08 +0100
Subject: [PATCH 062/219] [SPARK-9304] [BUILD] Improve backwards compatibility
 of SPARK-8401

Add back change-version-to-X.sh scripts, as wrappers for the new script, for
backwards compatibility

Author: Sean Owen

Closes #7639 from srowen/SPARK-9304 and squashes the following commits:

9ab2681 [Sean Owen] Add deprecation message to wrappers
3c8c202 [Sean Owen] Add back change-version-to-X.sh scripts, as wrappers for new script, for backwards compatibility

---
 dev/change-version-to-2.10.sh | 23 +++++++++++++++++++++++
 dev/change-version-to-2.11.sh | 23 +++++++++++++++++++++++
 2 files changed, 46 insertions(+)
 create mode 100755 dev/change-version-to-2.10.sh
 create mode 100755 dev/change-version-to-2.11.sh

diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh
new file mode 100755
index 0000000000000..0962d34c52f28
--- /dev/null
+++ b/dev/change-version-to-2.10.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This script exists for backwards compatibility. Use change-scala-version.sh instead.
+echo "This script is deprecated. Please instead run: change-scala-version.sh 2.10"
+
+$(dirname $0)/change-scala-version.sh 2.10
diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh
new file mode 100755
index 0000000000000..4ccfeef09fd04
--- /dev/null
+++ b/dev/change-version-to-2.11.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This script exists for backwards compatibility. Use change-scala-version.sh instead.
+echo "This script is deprecated. Please instead run: change-scala-version.sh 2.11"
+
+$(dirname $0)/change-scala-version.sh 2.11

From e2ec018e37cb699077b5fa2bd662f2055cb42296 Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sat, 25 Jul 2015 11:42:49 -0700
Subject: [PATCH 063/219] [SPARK-9285] [SQL] Fixes Row/InternalRow conversion
 for HadoopFsRelation

This is a follow-up of #7626. It fixes `Row`/`InternalRow` conversion for data sources extending `HadoopFsRelation` with `needConversion` being `true`.

Author: Cheng Lian

Closes #7649 from liancheng/spark-9285-conversion-fix and squashes the following commits:

036a50c [Cheng Lian] Addresses PR comment
f6d7c6a [Cheng Lian] Fixes Row/InternalRow conversion for HadoopFsRelation

---
 .../apache/spark/sql/sources/interfaces.scala | 23 ++++++++++++++++---
 .../SimpleTextHadoopFsRelationSuite.scala     |  5 ----
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 119bac786d478..7126145ddc010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.execution.RDDConversions
@@ -593,6 +593,11 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
    *
    * @since 1.4.0
    */
+  // TODO: Try to eliminate the extra Catalyst-to-Scala conversion when `needConversion` is true
+  //
+  // PR #7626 separated `Row` and `InternalRow` completely. One of the consequences is that we can
+  // no longer treat an `InternalRow` containing Catalyst values as a `Row`. Thus we have to
+  // introduce another row value conversion for data sources whose `needConversion` is true.
  def buildScan(requiredColumns: Array[String], inputFiles: Array[FileStatus]): RDD[Row] = {
     // Yeah, to workaround serialization...
     val dataSchema = this.dataSchema
@@ -611,14 +616,26 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio
     } else {
       rdd.asInstanceOf[RDD[InternalRow]]
     }
+
     converted.mapPartitions { rows =>
       val buildProjection = if (codegenEnabled) {
         GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes)
       } else {
         () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes)
       }
-      val mutableProjection = buildProjection()
-      rows.map(r => mutableProjection(r))
+
+      val projectedRows = {
+        val mutableProjection = buildProjection()
+        rows.map(r => mutableProjection(r))
+      }
+
+      if (needConversion) {
+        val requiredSchema = StructType(requiredColumns.map(dataSchema(_)))
+        val toScala = CatalystTypeConverters.createToScalaConverter(requiredSchema)
+        projectedRows.map(toScala(_).asInstanceOf[Row])
+      } else {
+        projectedRows
+      }
     }.asInstanceOf[RDD[Row]]
   }

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
index d761909d60e21..e8975e5f5cd08 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala
@@ -22,10 +22,6 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

-/*
-This is commented out due a bug in the data source API (SPARK-9291).
-
-
 class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
   override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName

@@ -54,4 +50,3 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest {
     }
   }
 }
-*/

From 2c94d0f24a37fa079b56d534b0b0a4574209215b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 25 Jul 2015 12:10:02 -0700
Subject: [PATCH 064/219] [SPARK-9192][SQL] add initialization phase for
 nondeterministic expression

Currently, nondeterministic expressions are broken without an explicit initialization phase.

Let me take `MonotonicallyIncreasingID` as an example. This expression needs mutable state to remember how many times it has been evaluated, so we use `transient var count: Long` there. By being transient, `count` will be reset to 0, and **only** to 0, when we serialize and deserialize the expression, as deserializing a transient variable resets it to its default value. There is *no way* to use a different initial value for `count` until we add an explicit initialization phase.

Another use case is local execution for `LocalRelation`: there is no serialize/deserialize step there, so we cannot reset mutable states at all.
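As a rough sketch of the contract this patch introduces (simplified from the `Nondeterministic` trait in the Expression.scala diff below; `PartitionCounter` is a hypothetical example, not part of this patch):

```scala
// Simplified from the trait added to Expression.scala in this patch.
trait Nondeterministic {
  private var initialized = false

  // Callers (e.g. InterpretedProjection below) invoke this once before any eval.
  final def initialize(): Unit = {
    if (!initialized) {
      initInternal()
      initialized = true
    }
  }

  // Subclasses set up mutable state here, with any initial value they like,
  // instead of relying on deserialization resetting a transient var to its default.
  protected def initInternal(): Unit
}

// Hypothetical expression-like example: the counter can now start from a
// caller-chosen value rather than always 0.
class PartitionCounter(startValue: Long) extends Nondeterministic {
  @transient private var count: Long = _
  override protected def initInternal(): Unit = { count = startValue }
  def next(): Long = { count += 1; count }
}
```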
Author: Wenchen Fan

Closes #7535 from cloud-fan/init and squashes the following commits:

6c6f332 [Wenchen Fan] add test
ef68ff4 [Wenchen Fan] fix comments
9eac85e [Wenchen Fan] move init code to interpreted class
bb7d838 [Wenchen Fan] pulls out nondeterministic expressions into a project
b4a4fc7 [Wenchen Fan] revert a refactor
86fee36 [Wenchen Fan] add initialization phase for nondeterministic expression

---
 .../sql/catalyst/analysis/Analyzer.scala      | 35 +++++-
 .../sql/catalyst/analysis/CheckAnalysis.scala | 19 +++-
 .../sql/catalyst/expressions/Expression.scala | 21 +++-
 .../sql/catalyst/expressions/Projection.scala | 10 ++
 .../sql/catalyst/expressions/predicates.scala |  4 +
 .../sql/catalyst/expressions/random.scala     | 12 +-
 .../plans/logical/basicOperators.scala        |  3 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala | 96 +++++++---------
 .../sql/catalyst/analysis/AnalysisTest.scala  | 105 ++++++++++++++++++
 .../expressions/ExpressionEvalHelper.scala    |  4 +
 .../MonotonicallyIncreasingID.scala           | 13 ++-
 .../expressions/SparkPartitionID.scala        |  8 +-
 12 files changed, 254 insertions(+), 76 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e916887187dc8..a723e92114b32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.types._

 import scala.collection.mutable.ArrayBuffer
@@ -78,7 +79,9 @@ class Analyzer(
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
       HiveTypeCoercion.typeCoercionRules ++
-      extendedResolutionRules : _*)
+      extendedResolutionRules : _*),
+    Batch("Nondeterministic", Once,
+      PullOutNondeterministic)
   )

   /**
@@ -910,6 +913,34 @@ class Analyzer(
       Project(finalProjectList, withWindow)
     }
   }
+
+  /**
+   * Pulls out nondeterministic expressions from a LogicalPlan that is not a Project or Filter,
+   * puts them into an inner Project, and finally projects them away at the outer Project.
+   */
+  object PullOutNondeterministic extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: Project => p
+      case f: Filter => f
+
+      // todo: It's hard to write a general rule to pull out nondeterministic expressions
+      // from a LogicalPlan; currently we only do it for UnaryNode, which has the same output
+      // schema as its child.
+ case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne + }.toMap + val newPlan = p.transformExpressions { case e => + nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) + } + val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child) + Project(p.output, newPlan.withNewChildren(newChild :: Nil)) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 81d473c1130f7..a373714832962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -38,10 +37,10 @@ trait CheckAnalysis { throw new AnalysisException(msg) } - def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { exprs.flatMap(_.collect { - case e: Generator => true - }).nonEmpty + case e: Generator => e + }).length > 1 } def checkAnalysis(plan: LogicalPlan): Unit = { @@ -137,13 +136,21 @@ trait CheckAnalysis { s""" |Failure when resolving conflicting references in Join: |$plan - |Conflicting attributes: ${conflictingAttributes.mkString(",")} - |""".stripMargin) + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") + case o if o.expressions.exists(!_.deterministic) && + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + failAnalysis( + s"""nondeterministic expressions are only allowed in Project or Filter, found: + | ${o.expressions.map(_.prettyString).mkString(",")} + |in operator ${operator.simpleString} + """.stripMargin) + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3f72e6e184db1..cb4c3f24b2721 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -196,7 +196,26 @@ trait Unevaluable extends Expression { * An expression that is nondeterministic. 
*/ trait Nondeterministic extends Expression { - override def deterministic: Boolean = false + final override def deterministic: Boolean = false + final override def foldable: Boolean = false + + private[this] var initialized = false + + final def initialize(): Unit = { + if (!initialized) { + initInternal() + initialized = true + } + } + + protected def initInternal(): Unit + + final override def eval(input: InternalRow = null): Any = { + require(initialized, "nondeterministic expression should be initialized before evaluation") + evalInternal(input) + } + + protected def evalInternal(input: InternalRow): Any } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index fb873e7e99547..c1ed9cf7ed6a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) + // null check is required for when Kryo invokes the no-arg constructor. protected val exprArray = if (expressions != null) expressions.toArray else null @@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + expressions.foreach(_.foreach { + case n: Nondeterministic => n.initialize() + case _ => + }) + private[this] val exprArray = expressions.toArray private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length) def currentValue: InternalRow = mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3f1bd2a925fe7..5bfe1cad24a3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -30,6 +30,10 @@ object InterpretedPredicate { create(BindReferences.bindReference(expression, inputSchema)) def create(expression: Expression): (InternalRow => Boolean) = { + expression.foreach { + case n: Nondeterministic => n.initialize() + case _ => + } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index aef24a5486466..8f30519697a37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** * Record ID within each partition. By being transient, the Random Number Generator is - * reset every time we serialize and deserialize it. + * reset every time we serialize, deserialize, or initialize it.
*/ - @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + @transient protected var rng: XORShiftRandom = _ + + override protected def initInternal(): Unit = { + rng = new XORShiftRandom(seed + TaskContext.getPartitionId) + } override def nullable: Boolean = false @@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic { /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ case class Rand(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextDouble() + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() def this() = this(Utils.random.nextLong()) @@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG { /** Generate a random column with i.i.d. gaussian random distribution. */ case class Randn(seed: Long) extends RDG { - override def eval(input: InternalRow): Double = rng.nextGaussian() + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() def this() = this(Utils.random.nextLong()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 57a12820fa4c6..8e1a236e2988c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override lazy val statistics: Statistics = { - val limit = limitExpr.eval(null).asInstanceOf[Int] + val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum Statistics(sizeInBytes = sizeInBytes) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7e67427237a65..ed645b618dc9b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.BeforeAndAfter - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +// todo: remove this and use AnalysisTest instead. 
object AnalysisSuite { val caseSensitiveConf = new SimpleCatalystConf(true) val caseInsensitiveConf = new SimpleCatalystConf(false) @@ -55,7 +52,7 @@ object AnalysisSuite { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(), + AttributeReference("d", DecimalType(10, 2))(), AttributeReference("e", ShortType)()) val nestedRelation = LocalRelation( @@ -81,8 +78,7 @@ object AnalysisSuite { } -class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { - import AnalysisSuite._ +class AnalysisSuite extends AnalysisTest { test("union project *") { val plan = (1 to 100) @@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyzer.execute(plan).resolved) + assertAnalysisSuccess(plan) } test("check project's resolved") { @@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { } test("analyze project") { - assert( - caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) === - Project(testRelation.output, testRelation)) - - assert( - caseSensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - val e = intercept[AnalysisException] { - caseSensitiveAnalyze( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) - } - assert(e.getMessage().toLowerCase.contains("cannot resolve")) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("TbL.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) - - assert( - caseInsensitiveAnalyzer.execute( - Project(Seq(UnresolvedAttribute("tBl.a")), - UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === - Project(testRelation.output, testRelation)) + checkAnalysis( + Project(Seq(UnresolvedAttribute("a")), testRelation), + Project(testRelation.output, testRelation)) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation)) + + assertAnalysisError( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Seq("cannot resolve")) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) + + checkAnalysis( + Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))), + Project(testRelation.output, testRelation), + caseSensitive = false) } test("resolve relations") { - val e = intercept[RuntimeException] { - caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) - } - assert(e.getMessage == "Table Not Found: tAbLe") + assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe")) - assert( - caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation) - assert( - caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false) - assert( - 
caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) + checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false) } - test("divide should be casted into fractional types") { - val testRelation2 = LocalRelation( - AttributeReference("a", StringType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType(10, 2))(), - AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyzer.execute( testRelation2.select( 'a / Literal(2) as 'div1, @@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList + // StringType will be promoted into Double assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DoubleType) // StringType will be promoted into Double + assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) } + + test("pull out nondeterministic expressions from unary LogicalPlan") { + val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + RepartitionByExpression(Seq(projected.toAttribute), + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala new file mode 100644 index 0000000000000..fdb4f28950daf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.types._ + +trait AnalysisTest extends PlanTest { + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + + val testRelation2 = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", DoubleType)(), + AttributeReference("d", DecimalType(10, 2))(), + AttributeReference("e", ShortType)()) + + val nestedRelation = LocalRelation( + AttributeReference("top", StructType( + StructField("duplicateField", StringType) :: + StructField("duplicateField", StringType) :: + StructField("differentCase", StringType) :: + StructField("differentcase", StringType) :: Nil + ))()) + + val nestedRelation2 = LocalRelation( + AttributeReference("top", StructType( + StructField("aField", StringType) :: + StructField("bField", StringType) :: + StructField("cField", StringType) :: Nil + ))()) + + val listRelation = LocalRelation( + AttributeReference("list", ArrayType(IntegerType))()) + + val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = { + val caseSensitiveConf = new SimpleCatalystConf(true) + val caseInsensitiveConf = new SimpleCatalystConf(false) + + val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf) + val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf) + + caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation) + + new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } -> + new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) { + override val extendedResolutionRules = EliminateSubQueries :: Nil + } + } + + protected def getAnalyzer(caseSensitive: Boolean) = { + if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer + } + + protected def checkAnalysis( + inputPlan: LogicalPlan, + expectedPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + val actualPlan = analyzer.execute(inputPlan) + analyzer.checkAnalysis(actualPlan) + comparePlans(actualPlan, expectedPlan) + } + + protected def assertAnalysisSuccess( + inputPlan: LogicalPlan, + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + + protected def assertAnalysisError( + inputPlan: LogicalPlan, + expectedErrors: Seq[String], + caseSensitive: Boolean = true): Unit = { + val analyzer = getAnalyzer(caseSensitive) + // todo: make sure we throw AnalysisException during analysis + val e = intercept[Exception] { + analyzer.checkAnalysis(analyzer.execute(inputPlan)) + } + expectedErrors.forall(e.getMessage.contains) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 4930219aa63cb..852a8b235f127 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -64,6 +64,10 @@ trait ExpressionEvalHelper { } protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { + expression.foreach { + case n: Nondeterministic => n.initialize() + case _ => + } expression.eval(inputRow) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 2645eb1854bce..eca36b3274420 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -37,17 +37,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with /** * Record ID within each partition. By being transient, count's value is reset to 0 every time - * we serialize and deserialize it. + * we serialize, deserialize, or initialize it. */ - @transient private[this] var count: Long = 0L + @transient private[this] var count: Long = _ - @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33 + @transient private[this] var partitionMask: Long = _ + + override protected def initInternal(): Unit = { + count = 0L + partitionMask = TaskContext.getPartitionId().toLong << 33 + } override def nullable: Boolean = false override def dataType: DataType = LongType - override def eval(input: InternalRow): Long = { + override protected def evalInternal(input: InternalRow): Long = { val currentCount = count count += 1 partitionMask + currentCount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 53ddd47e3e0c1..61ef079d89af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -33,9 +33,13 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi override def dataType: DataType = IntegerType - @transient private lazy val partitionId = TaskContext.getPartitionId() + @transient private[this] var partitionId: Int = _ - override def eval(input: InternalRow): Int = partitionId + override protected def initInternal(): Unit = { + partitionId = TaskContext.getPartitionId() + } + + override protected def evalInternal(input: InternalRow): Int = partitionId override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val idTerm = ctx.freshName("partitionId") From b1f4b4abfd8d038c3684685b245b5fd31b927da0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 25 Jul 2015 18:41:51 -0700 Subject: [PATCH 065/219] [SPARK-9348][SQL] Remove apply method on InternalRow. Author: Reynold Xin Closes #7665 from rxin/remove-row-apply and squashes the following commits: 0b43001 [Reynold Xin] support getString in UnsafeRow. 176d633 [Reynold Xin] apply -> get. 2941324 [Reynold Xin] [SPARK-9348][SQL] Remove apply method on InternalRow.
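The call-site impact is mechanical; a short sketch of the rewrite this commit applies throughout the codebase (illustrative only, not taken from the patch; assumes the catalyst classes are on the classpath):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

// Illustrative only: the two rewrites this commit applies at call sites.
def readString(row: InternalRow, ordinal: Int): UTF8String = {
  // Before this patch: row(ordinal).asInstanceOf[UTF8String]  (apply sugar)
  // Generic replacement: row.get(ordinal).asInstanceOf[UTF8String]
  // Preferred typed accessor, which avoids the cast entirely:
  row.getUTF8String(ordinal)
}
```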
--- .../sql/catalyst/expressions/UnsafeRow.java | 88 +++++++++---------- .../spark/sql/catalyst/InternalRow.scala | 32 +++---- .../expressions/codegen/CodeGenerator.scala | 2 +- .../expressions/MathFunctionsSuite.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 4 +- .../spark/sql/columnar/ColumnType.scala | 16 ++-- .../compression/compressionSchemes.scala | 2 +- .../sql/execution/SparkSqlSerializer2.scala | 4 +- .../datasources/DataSourceStrategy.scala | 6 +- .../spark/sql/execution/debug/package.scala | 2 +- .../spark/sql/execution/pythonUDFs.scala | 2 +- .../sql/expressions/aggregate/udaf.scala | 4 +- .../sql/parquet/ParquetTableOperations.scala | 6 +- .../sql/parquet/ParquetTableSupport.scala | 22 ++--- .../scala/org/apache/spark/sql/RowSuite.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 12 +-- .../NullableColumnAccessorSuite.scala | 2 +- .../columnar/NullableColumnBuilderSuite.scala | 2 +- .../compression/BooleanBitSetSuite.scala | 2 +- .../spark/sql/hive/HiveInspectors.scala | 6 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- .../spark/sql/hive/orc/OrcRelation.scala | 2 +- 22 files changed, 113 insertions(+), 111 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 225f6e6553d19..9be9089493335 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -231,84 +231,89 @@ public void setFloat(int ordinal, float value) { } @Override - public Object get(int i) { + public Object get(int ordinal) { throw new UnsupportedOperationException(); } @Override - public T getAs(int i) { + public T getAs(int ordinal) { throw new UnsupportedOperationException(); } @Override - public boolean isNullAt(int i) { - assertIndexIsValid(i); - return BitSetMethods.isSet(baseObject, baseOffset, i); + public boolean isNullAt(int ordinal) { + assertIndexIsValid(ordinal); + return BitSetMethods.isSet(baseObject, baseOffset, ordinal); } @Override - public boolean getBoolean(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(i)); + public boolean getBoolean(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getBoolean(baseObject, getFieldOffset(ordinal)); } @Override - public byte getByte(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(i)); + public byte getByte(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getByte(baseObject, getFieldOffset(ordinal)); } @Override - public short getShort(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(i)); + public short getShort(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getShort(baseObject, getFieldOffset(ordinal)); } @Override - public int getInt(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(i)); + public int getInt(int ordinal) { + assertIndexIsValid(ordinal); + return PlatformDependent.UNSAFE.getInt(baseObject, getFieldOffset(ordinal)); } @Override - public long getLong(int i) { - assertIndexIsValid(i); - return PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + public long getLong(int ordinal) { + assertIndexIsValid(ordinal); + return 
PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(ordinal)); } @Override - public float getFloat(int i) { - assertIndexIsValid(i); - if (isNullAt(i)) { + public float getFloat(int ordinal) { + assertIndexIsValid(ordinal); + if (isNullAt(ordinal)) { return Float.NaN; } else { - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(i)); + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); } } @Override - public double getDouble(int i) { - assertIndexIsValid(i); - if (isNullAt(i)) { + public double getDouble(int ordinal) { + assertIndexIsValid(ordinal); + if (isNullAt(ordinal)) { return Float.NaN; } else { - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(i)); + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } } @Override - public UTF8String getUTF8String(int i) { - assertIndexIsValid(i); - return isNullAt(i) ? null : UTF8String.fromBytes(getBinary(i)); + public UTF8String getUTF8String(int ordinal) { + assertIndexIsValid(ordinal); + return isNullAt(ordinal) ? null : UTF8String.fromBytes(getBinary(ordinal)); } @Override - public byte[] getBinary(int i) { - if (isNullAt(i)) { + public String getString(int ordinal) { + return getUTF8String(ordinal).toString(); + } + + @Override + public byte[] getBinary(int ordinal) { + if (isNullAt(ordinal)) { return null; } else { - assertIndexIsValid(i); - final long offsetAndSize = getLong(i); + assertIndexIsValid(ordinal); + final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); final byte[] bytes = new byte[size]; @@ -324,17 +329,12 @@ public byte[] getBinary(int i) { } @Override - public String getString(int i) { - return getUTF8String(i).toString(); - } - - @Override - public UnsafeRow getStruct(int i, int numFields) { - if (isNullAt(i)) { + public UnsafeRow getStruct(int ordinal, int numFields) { + if (isNullAt(ordinal)) { return null; } else { - assertIndexIsValid(i); - final long offsetAndSize = getLong(i); + assertIndexIsValid(ordinal); + final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); final UnsafeRow row = new UnsafeRow(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index f248b1f338acc..37f0f57e9e6d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.Decimal import org.apache.spark.unsafe.types.UTF8String /** @@ -29,35 +30,34 @@ abstract class InternalRow extends Serializable { def numFields: Int - def get(i: Int): Any + def get(ordinal: Int): Any - // TODO: Remove this. 
- def apply(i: Int): Any = get(i) + def getAs[T](ordinal: Int): T = get(ordinal).asInstanceOf[T] - def getAs[T](i: Int): T = get(i).asInstanceOf[T] + def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - def isNullAt(i: Int): Boolean = get(i) == null + def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal) - def getBoolean(i: Int): Boolean = getAs[Boolean](i) + def getByte(ordinal: Int): Byte = getAs[Byte](ordinal) - def getByte(i: Int): Byte = getAs[Byte](i) + def getShort(ordinal: Int): Short = getAs[Short](ordinal) - def getShort(i: Int): Short = getAs[Short](i) + def getInt(ordinal: Int): Int = getAs[Int](ordinal) - def getInt(i: Int): Int = getAs[Int](i) + def getLong(ordinal: Int): Long = getAs[Long](ordinal) - def getLong(i: Int): Long = getAs[Long](i) + def getFloat(ordinal: Int): Float = getAs[Float](ordinal) - def getFloat(i: Int): Float = getAs[Float](i) + def getDouble(ordinal: Int): Double = getAs[Double](ordinal) - def getDouble(i: Int): Double = getAs[Double](i) + def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal) - def getUTF8String(i: Int): UTF8String = getAs[UTF8String](i) + def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal) - def getBinary(i: Int): Array[Byte] = getAs[Array[Byte]](i) + def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal) - // This is only use for test - def getString(i: Int): String = getAs[UTF8String](i).toString + // This is only use for test and will throw a null pointer exception if the position is null. + def getString(ordinal: Int): String = getAs[UTF8String](ordinal).toString /** * Returns a struct from ordinal position. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 508882acbee5a..2a1e288cb8377 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -110,7 +110,7 @@ class CodeGenContext { case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" case t: StructType => s"$row.getStruct($ordinal, ${t.size})" - case _ => s"($jt)$row.apply($ordinal)" + case _ => s"($jt)$row.get($ordinal)" } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index a2b0fad7b7a04..6caf8baf24a81 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -158,7 +158,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) - val actual = plan(inputRow).apply(0) + val actual = plan(inputRow).get(0) if (!actual.asInstanceOf[Double].isNaN) { fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 00374d1fa3ef1..7c63179af6470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -211,7 +211,7 @@ private[sql] class StringColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[UTF8String] + val value = row.getUTF8String(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += STRING.actualSize(row, ordinal) @@ -241,7 +241,7 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] + val value = row.getDecimal(ordinal) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += FIXED_DECIMAL.defaultSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index ac42bde07c37d..c0ca52751b66c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -90,7 +90,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * boxing/unboxing costs whenever possible. */ def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to(toOrdinal) = from(fromOrdinal) + to(toOrdinal) = from.get(fromOrdinal) } /** @@ -329,11 +329,11 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { } override def getField(row: InternalRow, ordinal: Int): UTF8String = { - row(ordinal).asInstanceOf[UTF8String] + row.getUTF8String(ordinal) } override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { - to.update(toOrdinal, from(fromOrdinal)) + to.update(toOrdinal, from.getUTF8String(fromOrdinal)) } } @@ -347,7 +347,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { } override def getField(row: InternalRow, ordinal: Int): Int = { - row(ordinal).asInstanceOf[Int] + row.getInt(ordinal) } def setField(row: MutableRow, ordinal: Int, value: Int): Unit = { @@ -365,7 +365,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { } override def getField(row: InternalRow, ordinal: Int): Long = { - row(ordinal).asInstanceOf[Long] + row.getLong(ordinal) } override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = { @@ -388,7 +388,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } override def getField(row: InternalRow, ordinal: Int): Decimal = { - row(ordinal).asInstanceOf[Decimal] + row.getDecimal(ordinal) } override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { @@ -427,7 +427,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) } override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - row(ordinal).asInstanceOf[Array[Byte]] + row.getBinary(ordinal) } } @@ -440,7 +440,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { } override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - SparkSqlSerializer.serialize(row(ordinal)) + SparkSqlSerializer.serialize(row.get(ordinal)) } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 5abc1259a19ab..6150df6930b32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -128,7 +128,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { while (from.hasRemaining) { columnType.extract(from, value, 0) - if (value(0) == currentValue(0)) { + if (value.get(0) == currentValue.get(0)) { currentRun += 1 } else { // Writes current run diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 83c4e8733f15f..6ee833c7b2c94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -278,7 +278,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getUTF8String(i).getBytes out.writeInt(bytes.length) out.write(bytes) } @@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] + val value = row.getAs[Decimal](i) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 7f452daef33c5..cdbe42381a7e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -170,6 +170,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { execution.PhysicalRDD(projections.map(_.toAttribute), unionedRows) } + // TODO: refactor this thing. It is very complicated because it does projection internally. + // We should just put a project on top of this. private def mergeWithPartitionValues( schema: StructType, requiredColumns: Array[String], @@ -187,13 +189,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (i != -1) { // If yes, gets column value from partition values. (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = partitionValues(i) + mutableRow(ordinal) = partitionValues.get(i) } } else { // Otherwise, inherits the value from scanned data. 
val i = nonPartitionColumns.indexOf(name) (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = dataRow(i) + mutableRow(ordinal) = dataRow.get(i) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index e6081cb05bc2d..1fdcc6a850602 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -136,7 +136,7 @@ package object debug { tupleCount += 1 var i = 0 while (i < numColumns) { - val value = currentRow(i) + val value = currentRow.get(i) if (value != null) { columnStats(i).elementTypes += HashSet(value.getClass.getName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 40bf03a3f1a62..970c40dc61a3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -129,7 +129,7 @@ object EvaluatePython { val values = new Array[Any](row.numFields) var i = 0 while (i < row.numFields) { - values(i) = toJava(row(i), struct.fields(i).dataType) + values(i) = toJava(row.get(i), struct.fields(i).dataType) i += 1 } new GenericInternalRowWithSchema(values, struct) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala index 46f0fac861282..7a6e86779b185 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -121,7 +121,7 @@ class MutableAggregationBuffer private[sql] ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } - toScalaConverters(i)(underlyingBuffer(offsets(i))) + toScalaConverters(i)(underlyingBuffer.get(offsets(i))) } def update(i: Int, value: Any): Unit = { @@ -157,7 +157,7 @@ class InputAggregationBuffer private[sql] ( s"Could not access ${i}th value in this buffer because it only has $length values.") } // TODO: Use buffer schema to avoid using generic getter. - toScalaConverters(i)(underlyingInputBuffer(offsets(i))) + toScalaConverters(i)(underlyingInputBuffer.get(offsets(i))) } override def copy(): InputAggregationBuffer = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 8cab27d6e1c46..38bb1e3967642 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -159,7 +159,7 @@ private[sql] case class ParquetTableScan( // Parquet will leave partitioning columns empty, so we fill them in here. var i = 0 - while (i < requestedPartitionOrdinals.size) { + while (i < requestedPartitionOrdinals.length) { row(requestedPartitionOrdinals(i)._2) = partitionRowValues(requestedPartitionOrdinals(i)._1) i += 1 @@ -179,12 +179,12 @@ private[sql] case class ParquetTableScan( var i = 0 while (i < row.numFields) { - mutableRow(i) = row(i) + mutableRow(i) = row.get(i) i += 1 } // Parquet will leave partitioning columns empty, so we fill them in here. 
i = 0 - while (i < requestedPartitionOrdinals.size) { + while (i < requestedPartitionOrdinals.length) { mutableRow(requestedPartitionOrdinals(i)._2) = partitionRowValues(requestedPartitionOrdinals(i)._1) i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index c7c58e69d42ef..2c23d4e8a8146 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -217,9 +217,9 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo writer.startMessage() while(index < attributesSize) { // null values indicate optional fields but we do not check currently - if (record(index) != null) { + if (!record.isNullAt(index)) { writer.startField(attributes(index).name, index) - writeValue(attributes(index).dataType, record(index)) + writeValue(attributes(index).dataType, record.get(index)) writer.endField(attributes(index).name, index) } index = index + 1 @@ -277,10 +277,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo val fields = schema.fields.toArray writer.startGroup() var i = 0 - while(i < fields.size) { - if (struct(i) != null) { + while(i < fields.length) { + if (!struct.isNullAt(i)) { writer.startField(fields(i).name, i) - writeValue(fields(i).dataType, struct(i)) + writeValue(fields(i).dataType, struct.get(i)) writer.endField(fields(i).name, i) } i = i + 1 @@ -387,7 +387,7 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { writer.startMessage() while(index < attributesSize) { // null values indicate optional fields but we do not check currently - if (record(index) != null && record(index) != Nil) { + if (!record.isNullAt(index) && !record.isNullAt(index)) { writer.startField(attributes(index).name, index) consumeType(attributes(index).dataType, record, index) writer.endField(attributes(index).name, index) @@ -410,15 +410,15 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case TimestampType => writeTimestamp(record.getLong(index)) case FloatType => writer.addFloat(record.getFloat(index)) case DoubleType => writer.addDouble(record.getDouble(index)) - case StringType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) + case StringType => + writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) + case BinaryType => + writer.addBinary(Binary.fromByteArray(record.getBinary(index))) case d: DecimalType => if (d.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") } - writeDecimal(record(index).asInstanceOf[Decimal], d.precision) + writeDecimal(record.getDecimal(index), d.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 0e5c5abff85f6..c6804e84827c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -39,14 +39,14 @@ class RowSuite extends SparkFunSuite { assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === 
actual1.getBoolean(2)) - assert(expected(3) === actual1(3)) + assert(expected.get(3) === actual1.get(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) - assert(expected(3) === actual2(3)) + assert(expected.get(3) === actual2.get(3)) } test("SpecificMutableRow.update with null") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 3333fee6711c0..31e7b0e72e510 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -58,15 +58,15 @@ class ColumnStatsSuite extends SparkFunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_(0).asInstanceOf[T#InternalType]) + val values = rows.take(10).map(_.get(0).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) - assertResult(10, "Wrong null count")(stats(2)) - assertResult(20, "Wrong row count")(stats(3)) - assertResult(stats(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1)) + assertResult(10, "Wrong null count")(stats.get(2)) + assertResult(20, "Wrong row count")(stats.get(3)) + assertResult(stats.get(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 9eaa769846088..d421f4d8d091e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite { (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row(0) === randomRow(0)) + assert(row.get(0) === randomRow.get(0)) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 17e9ae464bcc0..cd8bf75ff1752 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -98,7 +98,7 @@ class NullableColumnBuilderSuite extends SparkFunSuite { columnType.extract(buffer) } - assert(actual === randomRow(0), "Extracted value didn't equal to the original one") + assert(actual === randomRow.get(0), "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index f606e2133bedc..33092c83a1a1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite { val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_(0)) + val values = rows.map(_.get(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 592cfa0ee8380..16977ce30cfff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -497,7 +497,7 @@ private[hive] trait HiveInspectors { x.setStructFieldData( result, fieldRefs.get(i), - wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector)) i += 1 } @@ -508,7 +508,7 @@ private[hive] trait HiveInspectors { val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { - result.add(wrap(row(i), fieldRefs.get(i).getFieldObjectInspector)) + result.add(wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector)) i += 1 } @@ -536,7 +536,7 @@ private[hive] trait HiveInspectors { cache: Array[AnyRef]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row.get(i), inspectors(i)) i += 1 } cache diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 34b629403e128..f0e0ca05a8aad 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -102,7 +102,7 @@ case class InsertIntoHiveTable( iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i)) i += 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 10623dc820316..58445095ad74f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -122,7 +122,7 @@ private[orc] class OrcOutputWriter( override def writeInternal(row: InternalRow): Unit = { var i = 0 while (i < row.numFields) { - reusableOutputBuffer(i) = wrappers(i)(row(i)) + reusableOutputBuffer(i) = wrappers(i)(row.get(i)) i += 1 } From 41a7cdf85de2d583d8b8759941a9d6c6e98cae4d Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Sat, 25 Jul 2015 22:56:25 -0700 Subject: [PATCH 066/219] [SPARK-8881] [SPARK-9260] Fix algorithm for scheduling executors on workers The current scheduling algorithm allocates one core at a time and, in doing so, ends up ignoring spark.executor.cores.
As a result, when spark.cores.max/spark.executor.cores (i.e., num_executors) < num_workers, executors are not launched and the app hangs. This PR fixes and refactors the scheduling algorithm. andrewor14 Author: Nishkam Ravi Author: nishkamravi2 Closes #7274 from nishkamravi2/master_scheduler and squashes the following commits: b998097 [nishkamravi2] Update Master.scala da0f491 [Nishkam Ravi] Update Master.scala 79084e8 [Nishkam Ravi] Update Master.scala 1daf25f [Nishkam Ravi] Update Master.scala f279cdf [Nishkam Ravi] Update Master.scala adec84b [Nishkam Ravi] Update Master.scala a06da76 [nishkamravi2] Update Master.scala 40c8f9f [nishkamravi2] Update Master.scala (to trigger retest) c11c689 [nishkamravi2] Update EventLoggingListenerSuite.scala 5d6a19c [nishkamravi2] Update Master.scala (for the purpose of issuing a retest) 2d6371c [Nishkam Ravi] Update Master.scala 66362d5 [nishkamravi2] Update Master.scala ee7cf0e [Nishkam Ravi] Improved scheduling algorithm for executors --- .../apache/spark/deploy/master/Master.scala | 112 ++++++++++++------ 1 file changed, 75 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 4615febf17d24..029f94d1020be 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -541,6 +541,7 @@ private[master] class Master( /** * Schedule executors to be launched on the workers. + * Returns an array containing the number of cores assigned to each worker. * * There are two modes of launching executors. The first attempts to spread out an application's * executors on as many workers as possible, while the second does the opposite (i.e. launch them @@ -551,39 +552,73 @@ private[master] class Master( * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the * worker by default, in which case only one executor may be launched on each worker. + * + * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core + * at a time). Consider the following example: a cluster has 4 workers with 16 cores each. + * User requests 3 executors (spark.cores.max = 48, spark.executor.cores = 16). If 1 core is + * allocated at a time, 12 cores from each worker would be assigned to each executor. + * Since 12 < 16, no executors would launch [SPARK-8881]. */ - private def startExecutorsOnWorkers(): Unit = { - // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app - // in the queue, then the second app, etc.
- if (spreadOutApps) { - // Try to spread out each app among all the workers, until it has all its cores - for (app <- waitingApps if app.coresLeft > 0) { - val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) - .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && - worker.coresFree >= app.desc.coresPerExecutor.getOrElse(1)) - .sortBy(_.coresFree).reverse - val numUsable = usableWorkers.length - val assigned = new Array[Int](numUsable) // Number of cores to give on each node - var toAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) - var pos = 0 - while (toAssign > 0) { - if (usableWorkers(pos).coresFree - assigned(pos) > 0) { - toAssign -= 1 - assigned(pos) += 1 + private[master] def scheduleExecutorsOnWorkers( + app: ApplicationInfo, + usableWorkers: Array[WorkerInfo], + spreadOutApps: Boolean): Array[Int] = { + // If the number of cores per executor is not specified, then we can just schedule + // 1 core at a time since we expect a single executor to be launched on each worker + val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(1) + val memoryPerExecutor = app.desc.memoryPerExecutorMB + val numUsable = usableWorkers.length + val assignedCores = new Array[Int](numUsable) // Number of cores to give to each worker + val assignedMemory = new Array[Int](numUsable) // Amount of memory to give to each worker + var coresToAssign = math.min(app.coresLeft, usableWorkers.map(_.coresFree).sum) + var freeWorkers = (0 until numUsable).toIndexedSeq + + def canLaunchExecutor(pos: Int): Boolean = { + usableWorkers(pos).coresFree - assignedCores(pos) >= coresPerExecutor && + usableWorkers(pos).memoryFree - assignedMemory(pos) >= memoryPerExecutor + } + + while (coresToAssign >= coresPerExecutor && freeWorkers.nonEmpty) { + freeWorkers = freeWorkers.filter(canLaunchExecutor) + freeWorkers.foreach { pos => + var keepScheduling = true + while (keepScheduling && canLaunchExecutor(pos) && coresToAssign >= coresPerExecutor) { + coresToAssign -= coresPerExecutor + assignedCores(pos) += coresPerExecutor + assignedMemory(pos) += memoryPerExecutor + + // Spreading out an application means spreading out its executors across as + // many workers as possible. If we are not spreading out, then we should keep + // scheduling executors on this worker until we use all of its resources. + // Otherwise, just move on to the next worker. + if (spreadOutApps) { + keepScheduling = false } - pos = (pos + 1) % numUsable - } - // Now that we've decided how many cores to give on each node, let's actually give them - for (pos <- 0 until numUsable if assigned(pos) > 0) { - allocateWorkerResourceToExecutors(app, assigned(pos), usableWorkers(pos)) } } - } else { - // Pack each app into as few workers as possible until we've assigned all its cores - for (worker <- workers if worker.coresFree > 0 && worker.state == WorkerState.ALIVE) { - for (app <- waitingApps if app.coresLeft > 0) { - allocateWorkerResourceToExecutors(app, app.coresLeft, worker) - } + } + assignedCores + } + + /** + * Schedule and launch executors on workers + */ + private def startExecutorsOnWorkers(): Unit = { + // Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app + // in the queue, then the second app, etc. 
+ for (app <- waitingApps if app.coresLeft > 0) { + val coresPerExecutor: Option[Int] = app.desc.coresPerExecutor + // Filter out workers that don't have enough resources to launch an executor + val usableWorkers = workers.toArray.filter(_.state == WorkerState.ALIVE) + .filter(worker => worker.memoryFree >= app.desc.memoryPerExecutorMB && + worker.coresFree >= coresPerExecutor.getOrElse(1)) + .sortBy(_.coresFree).reverse + val assignedCores = scheduleExecutorsOnWorkers(app, usableWorkers, spreadOutApps) + + // Now that we've decided how many cores to allocate on each worker, let's allocate them + for (pos <- 0 until usableWorkers.length if assignedCores(pos) > 0) { + allocateWorkerResourceToExecutors( + app, assignedCores(pos), coresPerExecutor, usableWorkers(pos)) } } } @@ -591,19 +626,22 @@ private[master] class Master( /** * Allocate a worker's resources to one or more executors. * @param app the info of the application which the executors belong to - * @param coresToAllocate cores on this worker to be allocated to this application + * @param assignedCores number of cores on this worker for this application + * @param coresPerExecutor number of cores per executor * @param worker the worker info */ private def allocateWorkerResourceToExecutors( app: ApplicationInfo, - coresToAllocate: Int, + assignedCores: Int, + coresPerExecutor: Option[Int], worker: WorkerInfo): Unit = { - val memoryPerExecutor = app.desc.memoryPerExecutorMB - val coresPerExecutor = app.desc.coresPerExecutor.getOrElse(coresToAllocate) - var coresLeft = coresToAllocate - while (coresLeft >= coresPerExecutor && worker.memoryFree >= memoryPerExecutor) { - val exec = app.addExecutor(worker, coresPerExecutor) - coresLeft -= coresPerExecutor + // If the number of cores per executor is specified, we divide the cores assigned + // to this worker evenly among the executors with no remainder. + // Otherwise, we launch a single executor that grabs all the assignedCores on this worker. + val numExecutors = coresPerExecutor.map { assignedCores / _ }.getOrElse(1) + val coresToAssign = coresPerExecutor.getOrElse(assignedCores) + for (i <- 1 to numExecutors) { + val exec = app.addExecutor(worker, coresToAssign) launchExecutor(worker, exec) app.state = ApplicationState.RUNNING } From 4a01bfc2a2e664186028ea32095d32d29c9f9e38 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 25 Jul 2015 23:52:37 -0700 Subject: [PATCH 067/219] [SPARK-9350][SQL] Introduce an InternalRow generic getter that requires a DataType Currently UnsafeRow cannot support a generic getter. However, if the data type is known, we can support a generic getter. Author: Reynold Xin Closes #7666 from rxin/generic-getter-with-datatype and squashes the following commits: ee2874c [Reynold Xin] Add a default implementation for getStruct. 1e109a0 [Reynold Xin] [SPARK-9350][SQL] Introduce an InternalRow generic getter that requires a DataType. 033ee88 [Reynold Xin] Removed getAs in non test code. 
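In outline, the change amounts to the following reduced sketch (an illustrative trait, not the actual InternalRow): a row format that cannot box arbitrary values, such as UnsafeRow, can still answer a generic lookup once the caller supplies the column's DataType, by dispatching to its specialized accessors.

import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

trait TypedRow {
  def getInt(ordinal: Int): Int
  def getLong(ordinal: Int): Long
  def getUTF8String(ordinal: Int): UTF8String

  // With the DataType in hand, a generic getter needs no untyped boxed slot:
  // it routes to the matching typed accessor.
  def get(ordinal: Int, dataType: DataType): Any = dataType match {
    case IntegerType => getInt(ordinal)
    case LongType    => getLong(ordinal)
    case StringType  => getUTF8String(ordinal)
    case other       => throw new UnsupportedOperationException(s"unsupported type: $other")
  }
}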
--- .../apache/spark/mllib/linalg/Matrices.scala | 8 +++-- .../apache/spark/mllib/linalg/Vectors.scala | 9 ++++-- .../sql/catalyst/expressions/UnsafeRow.java | 5 --- .../sql/catalyst/CatalystTypeConverters.scala | 16 ++++++---- .../spark/sql/catalyst/InternalRow.scala | 32 +++++++++++-------- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 5 +-- .../sql/catalyst/expressions/Projection.scala | 8 +++++ .../expressions/SpecificMutableRow.scala | 10 +++--- .../sql/catalyst/expressions/aggregates.scala | 2 +- .../expressions/complexTypeCreator.scala | 2 +- .../expressions/complexTypeExtractors.scala | 4 +-- .../spark/sql/catalyst/expressions/rows.scala | 8 +++++ .../expressions/ExpressionEvalHelper.scala | 7 ++-- .../expressions/MathFunctionsSuite.scala | 2 +- .../expressions/UnsafeRowConverterSuite.scala | 2 +- .../sql/execution/SparkSqlSerializer2.scala | 4 +-- .../spark/sql/execution/basicOperators.scala | 4 +-- .../datasources/DataSourceStrategy.scala | 4 +-- .../spark/sql/execution/debug/package.scala | 2 +- .../spark/sql/execution/pythonUDFs.scala | 2 +- .../sql/execution/stat/FrequentItems.scala | 8 ++--- .../sql/expressions/aggregate/udaf.scala | 10 ++++-- .../sql/parquet/ParquetTableOperations.scala | 2 +- .../sql/parquet/ParquetTableSupport.scala | 4 +-- .../scala/org/apache/spark/sql/RowSuite.scala | 11 ++++--- .../hive/execution/InsertIntoHiveTable.scala | 4 ++- .../spark/sql/hive/orc/OrcRelation.scala | 2 +- 28 files changed, 105 insertions(+), 74 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index b6e2c30fbf104..d82ba2456df1a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -179,12 +179,14 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getAs[Iterable[Double]](5).toArray + val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = row.getAs[Iterable[Int]](3).toArray - val rowIndices = row.getAs[Iterable[Int]](4).toArray + val colPtrs = + row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray + val rowIndices = + row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index c884aad08889f..0cb28d78bec05 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -209,11 +209,14 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = row.getAs[Iterable[Int]](2).toArray - val values = row.getAs[Iterable[Double]](3).toArray + val indices = + row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray + val values = + row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray new SparseVector(size, indices, values) case 1 => - val values = 
row.getAs[Iterable[Double]](3).toArray + val values = + row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray new DenseVector(values) } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 9be9089493335..87e5a89c19658 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -235,11 +235,6 @@ public Object get(int ordinal) { throw new UnsupportedOperationException(); } - @Override - public T getAs(int ordinal) { - throw new UnsupportedOperationException(); - } - @Override public boolean isNullAt(int ordinal) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 7416ddbaef3fc..d1d89a1f48329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -77,7 +77,7 @@ object CatalystTypeConverters { case LongType => LongConverter case FloatType => FloatConverter case DoubleType => DoubleConverter - case _ => IdentityConverter + case dataType: DataType => IdentityConverter(dataType) } converter.asInstanceOf[CatalystTypeConverter[Any, Any, Any]] } @@ -137,17 +137,19 @@ object CatalystTypeConverters { protected def toScalaImpl(row: InternalRow, column: Int): ScalaOutputType } - private object IdentityConverter extends CatalystTypeConverter[Any, Any, Any] { + private case class IdentityConverter(dataType: DataType) + extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = scalaValue override def toScala(catalystValue: Any): Any = catalystValue - override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column) + override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType) } private case class UDTConverter( udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] { override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue) override def toScala(catalystValue: Any): Any = udt.deserialize(catalystValue) - override def toScalaImpl(row: InternalRow, column: Int): Any = toScala(row.get(column)) + override def toScalaImpl(row: InternalRow, column: Int): Any = + toScala(row.get(column, udt.sqlType)) } /** Converter for arrays, sequences, and Java iterables. 
*/ @@ -184,7 +186,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row.get(column).asInstanceOf[Seq[Any]]) + toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]]) } private case class MapConverter( @@ -227,7 +229,7 @@ object CatalystTypeConverters { } override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = - toScala(row.get(column).asInstanceOf[Map[Any, Any]]) + toScala(row.get(column, MapType(keyType, valueType)).asInstanceOf[Map[Any, Any]]) } private case class StructConverter( @@ -311,7 +313,7 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.get(column).asInstanceOf[Decimal].toJavaBigDecimal + row.getDecimal(column).toJavaBigDecimal } private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 37f0f57e9e6d3..385d9671386dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -32,32 +32,36 @@ abstract class InternalRow extends Serializable { def get(ordinal: Int): Any - def getAs[T](ordinal: Int): T = get(ordinal).asInstanceOf[T] + def genericGet(ordinal: Int): Any = get(ordinal, null) + + def get(ordinal: Int, dataType: DataType): Any = get(ordinal) + + def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal) + def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) - def getByte(ordinal: Int): Byte = getAs[Byte](ordinal) + def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) - def getShort(ordinal: Int): Short = getAs[Short](ordinal) + def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) - def getInt(ordinal: Int): Int = getAs[Int](ordinal) + def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) - def getLong(ordinal: Int): Long = getAs[Long](ordinal) + def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) - def getFloat(ordinal: Int): Float = getAs[Float](ordinal) + def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) - def getDouble(ordinal: Int): Double = getAs[Double](ordinal) + def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) - def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal) + def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) - def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal) + def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal) + def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) // This is only use for test and will throw a null pointer exception if the 
position is null. - def getString(ordinal: Int): String = getAs[UTF8String](ordinal).toString + def getString(ordinal: Int): String = getUTF8String(ordinal).toString /** * Returns a struct from ordinal position. @@ -65,7 +69,7 @@ abstract class InternalRow extends Serializable { * @param ordinal position to get the struct from. * @param numFields number of fields the struct type has */ - def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal) + def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) override def toString: String = s"[${this.mkString(",")}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 1f7adcd36ec14..6b5c450e3fb0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -49,7 +49,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) case t: StructType => input.getStruct(ordinal, t.size) - case _ => input.get(ordinal) + case dataType => input.get(ordinal, dataType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 47ad3e089e4c7..e5b83cd31bf0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -375,7 +375,7 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def castStruct(from: StructType, to: StructType): Any => Any = { - val casts = from.fields.zip(to.fields).map { + val castFuncs: Array[(Any) => Any] = from.fields.zip(to.fields).map { case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? 
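As a rough standalone sketch of the same shape as the hunk above and the one that follows (illustrative helpers, not patch code): the cast closures are built once per struct type, and each per-row field read then carries its declared type.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{DataType, StructType}

// Built once per plan: one cast function per (source, target) field pair.
def buildFieldCasts(
    from: StructType,
    to: StructType,
    cast: (DataType, DataType) => (Any => Any)): Array[Any => Any] = {
  from.fields.zip(to.fields).map { case (f, t) => cast(f.dataType, t.dataType) }
}

// Applied once per row: each field is read with its declared type, then cast.
def castRow(row: InternalRow, from: StructType, castFuncs: Array[Any => Any]): Array[Any] = {
  val out = new Array[Any](row.numFields)
  var i = 0
  while (i < row.numFields) {
    out(i) = if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from(i).dataType))
    i += 1
  }
  out
}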
@@ -383,7 +383,8 @@ case class Cast(child: Expression, dataType: DataType) buildCast[InternalRow](_, row => { var i = 0 while (i < row.numFields) { - newRow.update(i, if (row.isNullAt(i)) null else casts(i)(row.get(i))) + newRow.update(i, + if (row.isNullAt(i)) null else castFuncs(i)(row.get(i, from.apply(i).dataType))) i += 1 } newRow.copy() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index c1ed9cf7ed6a0..cc89d74146b34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -225,6 +225,14 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getStruct(i: Int, numFields: Int): InternalRow = { + if (i < row1.numFields) { + row1.getStruct(i, numFields) + } else { + row2.getStruct(i - row1.numFields, numFields) + } + } + override def copy(): InternalRow = { val totalSize = row1.numFields + row2.numFields val copiedValues = new Array[Any](totalSize) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4b4833bd06a3b..5953a093dc684 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -221,6 +221,10 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def get(i: Int): Any = values(i).boxed + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values(ordinal).boxed.asInstanceOf[InternalRow] + } + override def isNullAt(i: Int): Boolean = values(i).isNull override def copy(): InternalRow = { @@ -245,8 +249,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String.fromString(value)) - override def getString(ordinal: Int): String = get(ordinal).toString - override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] currentValue.isNull = false @@ -316,8 +318,4 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } - - override def getAs[T](i: Int): T = { - values(i).boxed.asInstanceOf[T] - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 62b6cc834c9c9..42343d4d8d79c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -685,7 +685,7 @@ case class CombineSetsAndSumFunction( null } else { Cast(Literal( - casted.iterator.map(f => f.get(0)).reduceLeft( + casted.iterator.map(f => f.genericGet(0)).reduceLeft( base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), base.dataType).eval(null) } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 20b1eaab8e303..119168fa59f15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index c91122cda2a41..6331a9eb603ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -110,7 +110,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal) + input.asInstanceOf[InternalRow].get(ordinal, field.dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { @@ -142,7 +142,7 @@ case class GetArrayStructFields( protected override def nullSafeEval(input: Any): Any = { input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row.get(ordinal) + if (row == null) null else row.get(ordinal, field.dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 53779dd4049d1..daeabe8e90f1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -101,6 +101,10 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal override def get(i: Int): Any = values(i) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values(ordinal).asInstanceOf[InternalRow] + } + override def copy(): InternalRow = this } @@ -128,6 +132,10 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow { override def get(i: Int): Any = values(i) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = { + values(ordinal).asInstanceOf[InternalRow] + } + override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 852a8b235f127..8b0f90cf3a623 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -113,7 +113,7 @@ trait ExpressionEvalHelper { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) - val actual = plan(inputRow).get(0) + val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") @@ -194,13 +194,14 @@ trait ExpressionEvalHelper { var plan = generateProject( GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - var actual = plan(inputRow).get(0) + var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - actual = FromUnsafeProjection(expression.dataType :: Nil)(plan(inputRow)).get(0) + actual = FromUnsafeProjection(expression.dataType :: Nil)( + plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 6caf8baf24a81..21459a7c69838 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -158,7 +158,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) - val actual = plan(inputRow).get(0) + val actual = plan(inputRow).get(0, expression.dataType) if (!actual.asInstanceOf[Double].isNaN) { fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: NaN") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 4606bcb57311d..2834b54e8fb2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -183,7 +183,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) - assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 6ee833c7b2c94..c808442a4849b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -288,7 +288,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val bytes = row.getBinary(i) out.writeInt(bytes.length) out.write(bytes) } @@ -298,7 +298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.getAs[Decimal](i) + val value = row.getDecimal(i) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fdd7ad59aba50..fe429d862a0a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.types.StructType import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index cdbe42381a7e4..6b91e51ca52fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -189,13 +189,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { if (i != -1) { // If yes, gets column value from partition values. (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = partitionValues.get(i) + mutableRow(ordinal) = partitionValues.genericGet(i) } } else { // Otherwise, inherits the value from scanned data. 
val i = nonPartitionColumns.indexOf(name) (mutableRow: MutableRow, dataRow: InternalRow, ordinal: Int) => { - mutableRow(ordinal) = dataRow.get(i) + mutableRow(ordinal) = dataRow.genericGet(i) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 1fdcc6a850602..aeeb0e45270dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -136,7 +136,7 @@ package object debug { tupleCount += 1 var i = 0 while (i < numColumns) { - val value = currentRow.get(i) + val value = currentRow.get(i, output(i).dataType) if (value != null) { columnStats(i).elementTypes += HashSet(value.getClass.getName) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 970c40dc61a3c..ec084a299649e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -129,7 +129,7 @@ object EvaluatePython { val values = new Array[Any](row.numFields) var i = 0 while (i < row.numFields) { - values(i) = toJava(row.get(i), struct.fields(i).dataType) + values(i) = toJava(row.get(i, struct.fields(i).dataType), struct.fields(i).dataType) i += 1 } new GenericInternalRowWithSchema(values, struct) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index ec5c6950f37ad..78da2840dad69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{ArrayType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType} import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -85,17 +85,17 @@ private[sql] object FrequentItems extends Logging { val sizeOfMap = (1 / support).toInt val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap)) val originalSchema = df.schema - val colInfo = cols.map { name => + val colInfo: Array[(String, DataType)] = cols.map { name => val index = originalSchema.fieldIndex(name) (name, originalSchema.fields(index).dataType) - } + }.toArray val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { val thisMap = counts(i) - val key = row.get(i) + val key = row.get(i, colInfo(i)._2) thisMap.add(key, 1L) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala index 7a6e86779b185..4ada9eca7a035 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -110,6 +110,7 @@ private[sql] abstract class AggregationBuffer( * A Mutable [[Row]] representing an 
mutable aggregation buffer. */ class MutableAggregationBuffer private[sql] ( + schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, @@ -121,7 +122,7 @@ class MutableAggregationBuffer private[sql] ( throw new IllegalArgumentException( s"Could not access ${i}th value in this buffer because it only has $length values.") } - toScalaConverters(i)(underlyingBuffer.get(offsets(i))) + toScalaConverters(i)(underlyingBuffer.get(offsets(i), schema(i).dataType)) } def update(i: Int, value: Any): Unit = { @@ -134,6 +135,7 @@ class MutableAggregationBuffer private[sql] ( override def copy(): MutableAggregationBuffer = { new MutableAggregationBuffer( + schema, toCatalystConverters, toScalaConverters, bufferOffset, @@ -145,6 +147,7 @@ class MutableAggregationBuffer private[sql] ( * A [[Row]] representing an immutable aggregation buffer. */ class InputAggregationBuffer private[sql] ( + schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, @@ -157,11 +160,12 @@ class InputAggregationBuffer private[sql] ( s"Could not access ${i}th value in this buffer because it only has $length values.") } // TODO: Use buffer schema to avoid using generic getter. - toScalaConverters(i)(underlyingInputBuffer.get(offsets(i))) + toScalaConverters(i)(underlyingInputBuffer.get(offsets(i), schema(i).dataType)) } override def copy(): InputAggregationBuffer = { new InputAggregationBuffer( + schema, toCatalystConverters, toScalaConverters, bufferOffset, @@ -233,6 +237,7 @@ case class ScalaUDAF( lazy val inputAggregateBuffer: InputAggregationBuffer = new InputAggregationBuffer( + bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, bufferOffset, @@ -240,6 +245,7 @@ case class ScalaUDAF( lazy val mutableAggregateBuffer: MutableAggregationBuffer = new MutableAggregationBuffer( + bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, bufferOffset, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 38bb1e3967642..75cbbde4f1512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -179,7 +179,7 @@ private[sql] case class ParquetTableScan( var i = 0 while (i < row.numFields) { - mutableRow(i) = row.get(i) + mutableRow(i) = row.genericGet(i) i += 1 } // Parquet will leave partitioning columns empty, so we fill them in here. 
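A recurring shape in this commit, sketched here with illustrative names rather than patch code: row consumers hoist each column's DataType out of the per-row loop once, then call the typed get(ordinal, dataType) inside it, as InsertIntoHiveTable does later in this patch.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.DataType

def consumeRows(rows: Iterator[InternalRow], output: Seq[Attribute])(sink: Any => Unit): Unit = {
  // Hoisted once, so the hot loop never re-derives the schema.
  val dataTypes: Array[DataType] = output.map(_.dataType).toArray
  rows.foreach { row =>
    var i = 0
    while (i < dataTypes.length) {
      sink(if (row.isNullAt(i)) null else row.get(i, dataTypes(i)))
      i += 1
    }
  }
}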
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 2c23d4e8a8146..7b6a7f65d69db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -219,7 +219,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo // null values indicate optional fields but we do not check currently if (!record.isNullAt(index)) { writer.startField(attributes(index).name, index) - writeValue(attributes(index).dataType, record.get(index)) + writeValue(attributes(index).dataType, record.get(index, attributes(index).dataType)) writer.endField(attributes(index).name, index) } index = index + 1 @@ -280,7 +280,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo while(i < fields.length) { if (!struct.isNullAt(i)) { writer.startField(fields(i).name, i) - writeValue(fields(i).dataType, struct.get(i)) + writeValue(fields(i).dataType, struct.get(i, fields(i).dataType)) writer.endField(fields(i).name, i) } i = i + 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index c6804e84827c0..01b7c21e84159 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -30,23 +30,24 @@ class RowSuite extends SparkFunSuite { test("create row") { val expected = new GenericMutableRow(4) - expected.update(0, 2147483647) + expected.setInt(0, 2147483647) expected.setString(1, "this is a string") - expected.update(2, false) - expected.update(3, null) + expected.setBoolean(2, false) + expected.setNullAt(3) + val actual1 = Row(2147483647, "this is a string", false, null) assert(expected.numFields === actual1.size) assert(expected.getInt(0) === actual1.getInt(0)) assert(expected.getString(1) === actual1.getString(1)) assert(expected.getBoolean(2) === actual1.getBoolean(2)) - assert(expected.get(3) === actual1.get(3)) + assert(expected.isNullAt(3) === actual1.isNullAt(3)) val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) assert(expected.numFields === actual2.size) assert(expected.getInt(0) === actual2.getInt(0)) assert(expected.getString(1) === actual2.getString(1)) assert(expected.getBoolean(2) === actual2.getBoolean(2)) - assert(expected.get(3) === actual2.get(3)) + assert(expected.isNullAt(3) === actual2.isNullAt(3)) } test("SpecificMutableRow.update with null") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f0e0ca05a8aad..e4944caeff924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive._ +import org.apache.spark.sql.types.DataType import org.apache.spark.{SparkException, TaskContext} import scala.collection.JavaConversions._ @@ -96,13 +97,14 @@ case class InsertIntoHiveTable( val fieldOIs = 
standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray val wrappers = fieldOIs.map(wrapperFor) val outputData = new Array[Any](fieldOIs.length) + val dataTypes: Array[DataType] = child.output.map(_.dataType).toArray writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) iterator.foreach { row => var i = 0 while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i)) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) i += 1 } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 58445095ad74f..924f4d37ce21f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -122,7 +122,7 @@ private[orc] class OrcOutputWriter( override def writeInternal(row: InternalRow): Unit = { var i = 0 while (i < row.numFields) { - reusableOutputBuffer(i) = wrappers(i)(row.get(i)) + reusableOutputBuffer(i) = wrappers(i)(row.get(i, dataSchema(i).dataType)) i += 1 } From b79bf1df6238c087c3ec524344f1fc179719c5de Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Sun, 26 Jul 2015 14:02:20 +0100 Subject: [PATCH 068/219] [SPARK-9337] [MLLIB] Add an ut for Word2Vec to verify the empty vocabulary check jira: https://issues.apache.org/jira/browse/SPARK-9337 Word2Vec should throw exception when vocabulary is empty Author: Yuhao Yang Closes #7660 from hhbyyh/ut4Word2vec and squashes the following commits: 17a18cb [Yuhao Yang] add ut for word2vec --- .../org/apache/spark/mllib/feature/Word2VecSuite.scala | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 4cc8d1129b858..a864eec460f2b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -45,6 +45,16 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq)) } + test("Word2Vec throws exception when vocabulary is empty") { + intercept[IllegalArgumentException] { + val sentence = "a b c" + val localDoc = Seq(sentence, sentence) + val doc = sc.parallelize(localDoc) + .map(line => line.split(" ").toSeq) + new Word2Vec().setMinCount(10).fit(doc) + } + } + test("Word2VecModel") { val num = 2 val word2VecMap = Map( From 6c400b4f39be3fb5f473b8d2db11d239ea8ddf42 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 10:27:39 -0700 Subject: [PATCH 069/219] [SPARK-9354][SQL] Remove InternalRow.get generic getter call in Hive integration code. Replaced them with get(ordinal, datatype) so we can use UnsafeRow here. I passed the data types throughout. Author: Reynold Xin Closes #7669 from rxin/row-generic-getter-hive and squashes the following commits: 3467d8e [Reynold Xin] [SPARK-9354][SQL] Remove Internal.get generic getter call in Hive integration code. 
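Reduced to a sketch with illustrative names (not the real wrap signature): every wrap call site now receives a DataType, so struct fields can be read with the typed getter before being wrapped with their ObjectInspector.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

// Given the StructType, each field is fetched as get(ordinal, dataType);
// the real code then wraps each value with the field's ObjectInspector.
def structFieldValues(row: InternalRow, structType: StructType): Seq[Any] =
  structType.fields.indices.map(i => row.get(i, structType(i).dataType))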
--- .../spark/sql/hive/HiveInspectors.scala | 43 ++++++----- .../org/apache/spark/sql/hive/hiveUDFs.scala | 74 ++++++++++++------- .../spark/sql/hive/HiveInspectorSuite.scala | 53 +++++++------ 3 files changed, 102 insertions(+), 68 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 16977ce30cfff..f467500259c91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -46,7 +46,7 @@ import scala.collection.JavaConversions._ * long / scala.Long * short / scala.Short * byte / scala.Byte - * org.apache.spark.sql.types.Decimal + * [[org.apache.spark.sql.types.Decimal]] * Array[Byte] * java.sql.Date * java.sql.Timestamp @@ -54,7 +54,7 @@ import scala.collection.JavaConversions._ * Map: scala.collection.immutable.Map * List: scala.collection.immutable.Seq * Struct: - * org.apache.spark.sql.catalyst.expression.Row + * [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. * @@ -454,7 +454,7 @@ private[hive] trait HiveInspectors { * * NOTICE: the complex data type requires recursive wrapping. */ - def wrap(a: Any, oi: ObjectInspector): AnyRef = oi match { + def wrap(a: Any, oi: ObjectInspector, dataType: DataType): AnyRef = oi match { case x: ConstantObjectInspector => x.getWritableConstantValue case _ if a == null => null case x: PrimitiveObjectInspector => x match { @@ -488,43 +488,50 @@ private[hive] trait HiveInspectors { } case x: SettableStructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] // 1. create the pojo (most likely) object val result = x.create() var i = 0 while (i < fieldRefs.length) { // 2. set the property for the pojo + val tpe = structType(i).dataType x.setStructFieldData( result, fieldRefs.get(i), - wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector)) + wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: StructObjectInspector => val fieldRefs = x.getAllStructFieldRefs + val structType = dataType.asInstanceOf[StructType] val row = a.asInstanceOf[InternalRow] val result = new java.util.ArrayList[AnyRef](fieldRefs.length) var i = 0 while (i < fieldRefs.length) { - result.add(wrap(row.get(i), fieldRefs.get(i).getFieldObjectInspector)) + val tpe = structType(i).dataType + result.add(wrap(row.get(i, tpe), fieldRefs.get(i).getFieldObjectInspector, tpe)) i += 1 } result case x: ListObjectInspector => val list = new java.util.ArrayList[Object] + val tpe = dataType.asInstanceOf[ArrayType].elementType a.asInstanceOf[Seq[_]].foreach { - v => list.add(wrap(v, x.getListElementObjectInspector)) + v => list.add(wrap(v, x.getListElementObjectInspector, tpe)) } list case x: MapObjectInspector => + val keyType = dataType.asInstanceOf[MapType].keyType + val valueType = dataType.asInstanceOf[MapType].valueType // Some UDFs seem to assume we pass in a HashMap. 
val hashMap = new java.util.HashMap[AnyRef, AnyRef]() - hashMap.putAll(a.asInstanceOf[Map[_, _]].map { - case (k, v) => - wrap(k, x.getMapKeyObjectInspector) -> wrap(v, x.getMapValueObjectInspector) + hashMap.putAll(a.asInstanceOf[Map[_, _]].map { case (k, v) => + wrap(k, x.getMapKeyObjectInspector, keyType) -> + wrap(v, x.getMapValueObjectInspector, valueType) }) hashMap @@ -533,22 +540,24 @@ private[hive] trait HiveInspectors { def wrap( row: InternalRow, inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row.get(i), inspectors(i)) + cache(i) = wrap(row.get(i, dataTypes(i)), inspectors(i), dataTypes(i)) i += 1 } cache } def wrap( - row: Seq[Any], - inspectors: Seq[ObjectInspector], - cache: Array[AnyRef]): Array[AnyRef] = { + row: Seq[Any], + inspectors: Seq[ObjectInspector], + cache: Array[AnyRef], + dataTypes: Array[DataType]): Array[AnyRef] = { var i = 0 while (i < inspectors.length) { - cache(i) = wrap(row(i), inspectors(i)) + cache(i) = wrap(row(i), inspectors(i), dataTypes(i)) i += 1 } cache @@ -625,7 +634,7 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector))) + value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -636,7 +645,7 @@ private[hive] trait HiveInspectors { } else { val map = new java.util.HashMap[Object, Object]() value.asInstanceOf[Map[_, _]].foreach (entry => { - map.put(wrap(entry._1, keyOI), wrap(entry._2, valueOI)) + map.put(wrap(entry._1, keyOI, keyType), wrap(entry._2, valueOI, valueType)) }) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, map) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 3259b50acc765..54bf6bd67ff84 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -83,24 +83,22 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - type UDFType = UDF - override def deterministic: Boolean = isUDFDeterministic override def nullable: Boolean = true @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[UDF]() @transient - protected lazy val method = + private lazy val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) @transient - protected lazy val arguments = children.map(toInspector).toArray + private lazy val arguments = children.map(toInspector).toArray @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } @@ -109,7 +107,7 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre // Create parameter converters 
@transient - protected lazy val conversionHelper = new ConversionHelper(method, arguments) + private lazy val conversionHelper = new ConversionHelper(method, arguments) @transient lazy val dataType = javaClassToDataType(method.getReturnType) @@ -119,14 +117,19 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre method.getGenericReturnType(), ObjectInspectorOptions.JAVA) @transient - protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray // TODO: Finish input output types. override def eval(input: InternalRow): Any = { - unwrap( - FunctionRegistry.invoke(method, function, conversionHelper - .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), - returnInspector) + val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes) + val ret = FunctionRegistry.invoke( + method, + function, + conversionHelper.convertIfNecessary(inputs : _*): _*) + unwrap(ret, returnInspector) } override def toString: String = { @@ -135,47 +138,48 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre } // Adapter from Catalyst ExpressionResult to Hive DeferredObject -private[hive] class DeferredObjectAdapter(oi: ObjectInspector) +private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataType) extends DeferredObject with HiveInspectors { + private var func: () => Any = _ def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrap(func(), oi) + override def get(): AnyRef = wrap(func(), oi, dataType) } private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { - type UDFType = GenericUDF + + override def nullable: Boolean = true override def deterministic: Boolean = isUDFDeterministic - override def nullable: Boolean = true + override def foldable: Boolean = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient - lazy val function = funcWrapper.createFunction[UDFType]() + lazy val function = funcWrapper.createFunction[GenericUDF]() @transient - protected lazy val argumentInspectors = children.map(toInspector) + private lazy val argumentInspectors = children.map(toInspector) @transient - protected lazy val returnInspector = { + private lazy val returnInspector = { function.initializeAndFoldConstants(argumentInspectors.toArray) } @transient - protected lazy val isUDFDeterministic = { + private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) udfType != null && udfType.deterministic() } - override def foldable: Boolean = - isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] - @transient - protected lazy val deferedObjects = - argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] + private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) => + new DeferredObjectAdapter(inspect, child.dataType) + }.toArray[DeferredObject] lazy val dataType: DataType = inspectorToDataType(returnInspector) @@ -354,6 +358,9 @@ private[hive] case class HiveWindowFunction( // Output buffer. 
private var outputBuffer: Any = _ + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def init(): Unit = { evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) } @@ -368,8 +375,13 @@ private[hive] case class HiveWindowFunction( } override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) + wrap( + inputProjection(input), + inputInspectors, + new Array[AnyRef](children.length), + inputDataTypes) } + // Add input parameters for a single row. override def update(input: AnyRef): Unit = { evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) @@ -510,12 +522,15 @@ private[hive] case class HiveGenericUDTF( field => (inspectorToDataType(field.getFieldObjectInspector), true) } + @transient + private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { outputInspector // Make sure initialized. val inputProjection = new InterpretedProjection(children) - function.process(wrap(inputProjection(input), inputInspectors, udtInput)) + function.process(wrap(inputProjection(input), inputInspectors, udtInput, inputDataTypes)) collector.collectRows() } @@ -584,9 +599,12 @@ private[hive] case class HiveUDAFFunction( @transient protected lazy val cached = new Array[AnyRef](exprs.length) + @transient + private lazy val inputDataTypes: Array[DataType] = exprs.map(_.dataType).toArray + def update(input: InternalRow): Unit = { val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) + function.iterate(buffer, wrap(inputs, inspectors, cached, inputDataTypes)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 8bb498a06fc9e..0330013f5325e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -48,7 +48,11 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { ObjectInspectorOptions.JAVA).asInstanceOf[StructObjectInspector] val a = unwrap(state, soi).asInstanceOf[InternalRow] - val b = wrap(a, soi).asInstanceOf[UDAFPercentile.State] + + val dt = new StructType() + .add("counts", MapType(LongType, LongType)) + .add("percentiles", ArrayType(DoubleType)) + val b = wrap(a, soi, dt).asInstanceOf[UDAFPercentile.State] val sfCounts = soi.getStructFieldRef("counts") val sfPercentiles = soi.getStructFieldRef("percentiles") @@ -158,44 +162,45 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val writableOIs = dataTypes.map(toWritableInspector) val nullRow = data.map(d => null) - checkValues(nullRow, nullRow.zip(writableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(nullRow, nullRow.zip(writableOIs).zip(dataTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) // struct couldn't be constant, sweep it out val constantExprs = data.filter(!_.dataType.isInstanceOf[StructType]) + val constantTypes = constantExprs.map(_.dataType) val constantData = constantExprs.map(_.eval()) val constantNullData = constantData.map(_ => null) val constantWritableOIs = constantExprs.map(e => toWritableInspector(e.dataType)) val constantNullWritableOIs = constantExprs.map(e => toInspector(Literal.create(null, e.dataType))) - 
checkValues(constantData, constantData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantData, constantData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantData.zip(constantNullWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantData.zip(constantNullWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) - checkValues(constantNullData, constantNullData.zip(constantWritableOIs).map { - case (d, oi) => unwrap(wrap(d, oi), oi) + checkValues(constantNullData, constantNullData.zip(constantWritableOIs).zip(constantTypes).map { + case ((d, oi), dt) => unwrap(wrap(d, oi, dt), oi) }) } test("wrap / unwrap primitive writable object inspector") { val writableOIs = dataTypes.map(toWritableInspector) - checkValues(row, row.zip(writableOIs).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(writableOIs).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } test("wrap / unwrap primitive java object inspector") { val ois = dataTypes.map(toInspector) - checkValues(row, row.zip(ois).map { - case (data, oi) => unwrap(wrap(data, oi), oi) + checkValues(row, row.zip(ois).zip(dataTypes).map { + case ((data, oi), dt) => unwrap(wrap(data, oi, dt), oi) }) } @@ -205,31 +210,33 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { }) val inspector = toInspector(dt) checkValues(row, - unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow]) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + unwrap(wrap(InternalRow.fromSeq(row), inspector, dt), inspector).asInstanceOf[InternalRow]) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } test("wrap / unwrap Array Type") { val dt = ArrayType(dataTypes(0)) val d = row(0) :: row(0) :: Nil - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } test("wrap / unwrap Map Type") { val dt = MapType(dataTypes(0), dataTypes(1)) val d = Map(row(0) -> row(1)) - checkValue(d, unwrap(wrap(d, toInspector(dt)), toInspector(dt))) - checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) checkValue(d, - unwrap(wrap(d, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(d, toInspector(Literal.create(d, dt)), dt), toInspector(Literal.create(d, dt)))) checkValue(d, - unwrap(wrap(null, toInspector(Literal.create(d, dt))), toInspector(Literal.create(d, dt)))) + unwrap(wrap(null, toInspector(Literal.create(d, dt)), dt), + toInspector(Literal.create(d, dt)))) } } From 
fb5d43fb2529d78d55f1fe8d365191c946153640 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Sun, 26 Jul 2015 10:29:22 -0700 Subject: [PATCH 070/219] [SPARK-9356][SQL]Remove the internal use of DecimalType.Unlimited JIRA: https://issues.apache.org/jira/browse/SPARK-9356 Author: Yijie Shen Closes #7671 from yjshen/deprecated_unlimit and squashes the following commits: c707f56 [Yijie Shen] remove pattern matching in changePrecision 4a1823c [Yijie Shen] remove internal occurrence of Decimal.Unlimited --- .../spark/sql/catalyst/expressions/Cast.scala | 22 +++++++------------ .../expressions/NullFunctionsSuite.scala | 2 +- .../datasources/PartitioningUtils.scala | 3 +-- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e5b83cd31bf0f..e208262da96dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -507,20 +507,14 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def changePrecision(d: String, decimalType: DecimalType, - evPrim: String, evNull: String): String = { - decimalType match { - case DecimalType.Unlimited => - s"$evPrim = $d;" - case DecimalType.Fixed(precision, scale) => - s""" - if ($d.changePrecision($precision, $scale)) { - $evPrim = $d; - } else { - $evNull = true; - } - """ - } - } + evPrim: String, evNull: String): String = + s""" + if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { + $evPrim = $d; + } else { + $evNull = true; + } + """ private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = { from match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 9efe44c83293d..ace6c15dc8418 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -92,7 +92,7 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val nullOnly = Seq(Literal("x"), Literal.create(null, DoubleType), - Literal.create(null, DecimalType.Unlimited), + Literal.create(null, DecimalType.USER_DEFAULT), Literal(Float.MaxValue), Literal(false)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 9d0fa894b9942..66dfcc308ceca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -179,8 +179,7 @@ private[sql] object PartitioningUtils { * {{{ * NullType -> * IntegerType -> LongType -> - * DoubleType -> DecimalType.Unlimited -> - * StringType + * DoubleType -> StringType * }}} */ private[sql] def resolvePartitions( From 1cf19760d61a5a17bd175a906d34a2940141b76d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sun, 26 Jul 2015 13:03:13 -0700 Subject: [PATCH 071/219] [SPARK-9352] [SPARK-9353] Add tests for standalone scheduling code This also fixes a small issue in the standalone Master that was uncovered by the 
new tests. For more detail, read the description of SPARK-9353. Author: Andrew Or Closes #7668 from andrewor14/standalone-scheduling-tests and squashes the following commits: d852faf [Andrew Or] Add tests + fix scheduling with memory limits --- .../apache/spark/deploy/master/Master.scala | 8 +- .../spark/deploy/master/MasterSuite.scala | 199 +++++++++++++++++- 2 files changed, 202 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 029f94d1020be..51b3f0dead73e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -559,7 +559,7 @@ private[master] class Master( * allocated at a time, 12 cores from each worker would be assigned to each executor. * Since 12 < 16, no executors would launch [SPARK-8881]. */ - private[master] def scheduleExecutorsOnWorkers( + private def scheduleExecutorsOnWorkers( app: ApplicationInfo, usableWorkers: Array[WorkerInfo], spreadOutApps: Boolean): Array[Int] = { @@ -585,7 +585,11 @@ private[master] class Master( while (keepScheduling && canLaunchExecutor(pos) && coresToAssign >= coresPerExecutor) { coresToAssign -= coresPerExecutor assignedCores(pos) += coresPerExecutor - assignedMemory(pos) += memoryPerExecutor + // If cores per executor is not set, we are assigning 1 core at a time + // without actually meaning to launch 1 executor for each core assigned + if (app.desc.coresPerExecutor.isDefined) { + assignedMemory(pos) += memoryPerExecutor + } // Spreading out an application means spreading out its executors across as // many workers as possible. If we are not spreading out, then we should keep diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index a8fbaf1d9da0a..4d7016d1e594b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -25,14 +25,15 @@ import scala.language.postfixOps import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.scalatest.Matchers +import org.scalatest.{Matchers, PrivateMethodTester} import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy._ +import org.apache.spark.rpc.RpcEnv -class MasterSuite extends SparkFunSuite with Matchers with Eventually { +class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester { test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) @@ -142,4 +143,196 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { } } + test("basic scheduling - spread out") { + testBasicScheduling(spreadOut = true) + } + + test("basic scheduling - no spread out") { + testBasicScheduling(spreadOut = false) + } + + test("scheduling with max cores - spread out") { + testSchedulingWithMaxCores(spreadOut = true) + } + + test("scheduling with max cores - no spread out") { + testSchedulingWithMaxCores(spreadOut = false) + } + + test("scheduling with cores per executor - spread out") { + testSchedulingWithCoresPerExecutor(spreadOut = true) + } + + test("scheduling with cores per executor - no spread 
out") { + testSchedulingWithCoresPerExecutor(spreadOut = false) + } + + test("scheduling with cores per executor AND max cores - spread out") { + testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut = true) + } + + test("scheduling with cores per executor AND max cores - no spread out") { + testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut = false) + } + + private def testBasicScheduling(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo = makeAppInfo(1024) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + val scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + assert(scheduledCores(0) === 10) + assert(scheduledCores(1) === 10) + assert(scheduledCores(2) === 10) + } + + private def testSchedulingWithMaxCores(spreadOut: Boolean): Unit = { + val master = makeMaster() + val appInfo1 = makeAppInfo(1024, maxCores = Some(8)) + val appInfo2 = makeAppInfo(1024, maxCores = Some(16)) + val workerInfo = makeWorkerInfo(4096, 10) + val workerInfos = Array(workerInfo, workerInfo, workerInfo) + var scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + // With spreading out, each worker should be assigned a few cores + if (spreadOut) { + assert(scheduledCores(0) === 3) + assert(scheduledCores(1) === 3) + assert(scheduledCores(2) === 2) + } else { + // Without spreading out, the cores should be concentrated on the first worker + assert(scheduledCores(0) === 8) + assert(scheduledCores(1) === 0) + assert(scheduledCores(2) === 0) + } + // Now test the same thing with max cores > cores per worker + scheduledCores = master.invokePrivate( + _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut)) + assert(scheduledCores.length === 3) + if (spreadOut) { + assert(scheduledCores(0) === 6) + assert(scheduledCores(1) === 5) + assert(scheduledCores(2) === 5) + } else { + // Without spreading out, the first worker should be fully booked, + // and the leftover cores should spill over to the second worker only. 
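+      // (maxCores = 16 with 10 cores per worker: the first worker takes all 10 cores,
+      // leaving 6 to spill over to the second)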
+      assert(scheduledCores(0) === 10)
+      assert(scheduledCores(1) === 6)
+      assert(scheduledCores(2) === 0)
+    }
+  }
+
+  private def testSchedulingWithCoresPerExecutor(spreadOut: Boolean): Unit = {
+    val master = makeMaster()
+    val appInfo1 = makeAppInfo(1024, coresPerExecutor = Some(2))
+    val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2))
+    val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3))
+    val workerInfo = makeWorkerInfo(4096, 10)
+    val workerInfos = Array(workerInfo, workerInfo, workerInfo)
+    // Each worker should end up with 4 executors with 2 cores each
+    // This should be 4 because of the memory restriction on each worker
+    var scheduledCores = master.invokePrivate(
+      _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut))
+    assert(scheduledCores.length === 3)
+    assert(scheduledCores(0) === 8)
+    assert(scheduledCores(1) === 8)
+    assert(scheduledCores(2) === 8)
+    // Now test the same thing without running into the worker memory limit
+    // Each worker should now end up with 5 executors with 2 cores each
+    scheduledCores = master.invokePrivate(
+      _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut))
+    assert(scheduledCores.length === 3)
+    assert(scheduledCores(0) === 10)
+    assert(scheduledCores(1) === 10)
+    assert(scheduledCores(2) === 10)
+    // Now test the same thing with a number of cores per executor that 10 is not divisible by
+    scheduledCores = master.invokePrivate(
+      _scheduleExecutorsOnWorkers(appInfo3, workerInfos, spreadOut))
+    assert(scheduledCores.length === 3)
+    assert(scheduledCores(0) === 9)
+    assert(scheduledCores(1) === 9)
+    assert(scheduledCores(2) === 9)
+  }
+
+  // Sorry for the long method name!
+  private def testSchedulingWithCoresPerExecutorAndMaxCores(spreadOut: Boolean): Unit = {
+    val master = makeMaster()
+    val appInfo1 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(4))
+    val appInfo2 = makeAppInfo(256, coresPerExecutor = Some(2), maxCores = Some(20))
+    val appInfo3 = makeAppInfo(256, coresPerExecutor = Some(3), maxCores = Some(20))
+    val workerInfo = makeWorkerInfo(4096, 10)
+    val workerInfos = Array(workerInfo, workerInfo, workerInfo)
+    // We should only launch two executors, each with exactly 2 cores
+    var scheduledCores = master.invokePrivate(
+      _scheduleExecutorsOnWorkers(appInfo1, workerInfos, spreadOut))
+    assert(scheduledCores.length === 3)
+    if (spreadOut) {
+      assert(scheduledCores(0) === 2)
+      assert(scheduledCores(1) === 2)
+      assert(scheduledCores(2) === 0)
+    } else {
+      assert(scheduledCores(0) === 4)
+      assert(scheduledCores(1) === 0)
+      assert(scheduledCores(2) === 0)
+    }
+    // Test max cores > number of cores per worker
+    scheduledCores = master.invokePrivate(
+      _scheduleExecutorsOnWorkers(appInfo2, workerInfos, spreadOut))
+    assert(scheduledCores.length === 3)
+    if (spreadOut) {
+      assert(scheduledCores(0) === 8)
+      assert(scheduledCores(1) === 6)
+      assert(scheduledCores(2) === 6)
+    } else {
+      assert(scheduledCores(0) === 10)
+      assert(scheduledCores(1) === 10)
+      assert(scheduledCores(2) === 0)
+    }
+    // Test max cores > number of cores per worker AND
+    // a number of cores per executor that 10 is not divisible by
+    scheduledCores = master.invokePrivate(
+      _scheduleExecutorsOnWorkers(appInfo3, workerInfos, spreadOut))
+    assert(scheduledCores.length === 3)
+    if (spreadOut) {
+      assert(scheduledCores(0) === 6)
+      assert(scheduledCores(1) === 6)
+      assert(scheduledCores(2) === 6)
+    } else {
+      assert(scheduledCores(0) === 9)
+      assert(scheduledCores(1) === 9)
+      assert(scheduledCores(2) === 0)
+    }
+  }
+
+  // 
=============================== + // | Utility methods for testing | + // =============================== + + private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers) + + private def makeMaster(conf: SparkConf = new SparkConf): Master = { + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf) + master + } + + private def makeAppInfo( + memoryPerExecutorMb: Int, + coresPerExecutor: Option[Int] = None, + maxCores: Option[Int] = None): ApplicationInfo = { + val desc = new ApplicationDescription( + "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor) + val appId = System.currentTimeMillis.toString + new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue) + } + + private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = { + val workerId = System.currentTimeMillis.toString + new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, 101, "address") + } + } From 6b2baec04fa3d928f0ee84af8c2723ac03a4648c Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Sun, 26 Jul 2015 13:35:16 -0700 Subject: [PATCH 072/219] [SPARK-9326] Close lock file used for file downloads. A lock file is used to ensure multiple executors running on the same machine don't download the same file concurrently. Spark never closes these lock files (releasing the lock does not close the underlying file); this commit fixes that. cc vanzin (looks like you've been involved in various other fixes surrounding these lock files) Author: Kay Ousterhout Closes #7650 from kayousterhout/SPARK-9326 and squashes the following commits: 0401bd1 [Kay Ousterhout] Close lock file used for file downloads. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c5816949cd360..c4012d0e83f7d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -443,11 +443,11 @@ private[spark] object Utils extends Logging { val lockFileName = s"${url.hashCode}${timestamp}_lock" val localDir = new File(getLocalDir(conf)) val lockFile = new File(localDir, lockFileName) - val raf = new RandomAccessFile(lockFile, "rw") + val lockFileChannel = new RandomAccessFile(lockFile, "rw").getChannel() // Only one executor entry. // The FileLock is only used to control synchronization for executors download file, // it's always safe regardless of lock type (mandatory or advisory). - val lock = raf.getChannel().lock() + val lock = lockFileChannel.lock() val cachedFile = new File(localDir, cachedFileName) try { if (!cachedFile.exists()) { @@ -455,6 +455,7 @@ private[spark] object Utils extends Logging { } } finally { lock.release() + lockFileChannel.close() } copyFile( url, From c025c3d0a1fdfbc45b64db9c871176b40b4a7b9b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 26 Jul 2015 16:49:19 -0700 Subject: [PATCH 073/219] [SPARK-9095] [SQL] Removes the old Parquet support This PR removes the old Parquet support: - Removes the old `ParquetRelation` together with related SQL configuration, plan nodes, strategies, utility classes, and test suites. 
- Renames `ParquetRelation2` to `ParquetRelation`
- Renames `RowReadSupport` and `RowRecordMaterializer` to `CatalystReadSupport` and
  `CatalystRecordMaterializer` respectively, and moves them to separate files. This follows the
  naming convention used in other Parquet data models implemented in parquet-mr. It should be
  easier for developers who are familiar with Parquet to follow.

There's still some other code that can be cleaned up, especially `RowWriteSupport`, but I'd like
to leave that part to SPARK-8848.

Author: Cheng Lian 

Closes #7441 from liancheng/spark-9095 and squashes the following commits:

c7b6e38 [Cheng Lian] Removes WriteToFile
2d688d6 [Cheng Lian] Renames ParquetRelation2 to ParquetRelation
ca9e1b7 [Cheng Lian] Removes old Parquet support
---
 .../plans/logical/basicOperators.scala        |   6 -
 .../org/apache/spark/sql/DataFrame.scala      |   9 +-
 .../apache/spark/sql/DataFrameReader.scala    |   8 +-
 .../scala/org/apache/spark/sql/SQLConf.scala  |   6 -
 .../org/apache/spark/sql/SQLContext.scala     |   6 +-
 .../spark/sql/execution/SparkStrategies.scala |  58 +-
 .../sql/parquet/CatalystReadSupport.scala     | 153 ++++
 .../parquet/CatalystRecordMaterializer.scala  |  41 +
 .../sql/parquet/CatalystSchemaConverter.scala |   5 +
 .../spark/sql/parquet/ParquetConverter.scala  |   1 +
 .../spark/sql/parquet/ParquetRelation.scala   | 843 +++++++++++++++---
 .../sql/parquet/ParquetTableOperations.scala  | 492 ----------
 .../sql/parquet/ParquetTableSupport.scala     | 151 +---
 .../spark/sql/parquet/ParquetTypes.scala      |  42 +-
 .../apache/spark/sql/parquet/newParquet.scala | 732 ---------------
 .../sql/parquet/ParquetFilterSuite.scala      |  65 +-
 .../spark/sql/parquet/ParquetIOSuite.scala    |  37 +-
 .../ParquetPartitionDiscoverySuite.scala      |   2 +-
 .../spark/sql/parquet/ParquetQuerySuite.scala |  27 +-
 .../sql/parquet/ParquetSchemaSuite.scala      |  12 +-
 .../apache/spark/sql/hive/HiveContext.scala   |   2 -
 .../spark/sql/hive/HiveMetastoreCatalog.scala |  22 +-
 .../spark/sql/hive/HiveStrategies.scala       | 141 +--
 .../spark/sql/hive/HiveParquetSuite.scala     |  86 +-
 .../sql/hive/MetastoreDataSourcesSuite.scala  |  14 +-
 .../sql/hive/execution/SQLQuerySuite.scala    |  54 +-
 .../apache/spark/sql/hive/parquetSuites.scala | 174 +---
 27 files changed, 1037 insertions(+), 2152 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 8e1a236e2988c..af68358daf5f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -186,12 +186,6 @@ case class WithWindowDefinition(
   override def output: Seq[Attribute] = child.output
 }
 
-case class WriteToFile(
-    path: String,
-    child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output
-}
-
 /**
  * @param order  The ordering expressions
  * @param global True means global sorting apply for entire data set,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fa942a1f8fd93..114ab91d10aa0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -139,8 +139,7 @@ class DataFrame private[sql]( // happen right away to let these side effects take place eagerly. case _: Command | _: InsertIntoTable | - _: CreateTableUsingAsSelect | - _: WriteToFile => + _: CreateTableUsingAsSelect => LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) case _ => queryExecution.analyzed @@ -1615,11 +1614,7 @@ class DataFrame private[sql]( */ @deprecated("Use write.parquet(path)", "1.4.0") def saveAsParquetFile(path: String): Unit = { - if (sqlContext.conf.parquetUseDataSourceApi) { - write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - } else { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } + write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index e9d782cdcd667..eb09807f9d9c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -21,16 +21,16 @@ import java.util.Properties import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, Partition} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.json.JSONRelation -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types.StructType +import org.apache.spark.{Logging, Partition} /** * :: Experimental :: @@ -259,7 +259,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { }.toArray sqlContext.baseRelationToDataFrame( - new ParquetRelation2( + new ParquetRelation( globbedPaths.map(_.toString), None, None, extraOptions.toMap)(sqlContext)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2a641b9d64a95..9b2dbd7442f5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -276,10 +276,6 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "Enables Parquet filter push-down optimization when set to true.") - val PARQUET_USE_DATA_SOURCE_API = booleanConf("spark.sql.parquet.useDataSourceApi", - defaultValue = Some(true), - doc = "") - val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( key = "spark.sql.parquet.followParquetFormatSpec", defaultValue = Some(false), @@ -456,8 +452,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED) - private[spark] def parquetUseDataSourceApi: Boolean = getConf(PARQUET_USE_DATA_SOURCE_API) - private[spark] def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) private[spark] def 
verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 49bfe74b680af..0e25e06e99ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -870,7 +870,6 @@ class SQLContext(@transient val sparkContext: SparkContext) LeftSemiJoin :: HashJoin :: InMemoryScans :: - ParquetOperations :: BasicOperators :: CartesianProduct :: BroadcastNestedLoopJoin :: Nil) @@ -1115,11 +1114,8 @@ class SQLContext(@transient val sparkContext: SparkContext) def parquetFile(paths: String*): DataFrame = { if (paths.isEmpty) { emptyDataFrame - } else if (conf.parquetUseDataSourceApi) { - read.parquet(paths : _*) } else { - DataFrame(this, parquet.ParquetRelation( - paths.mkString(","), Some(sparkContext.hadoopConfiguration), this)) + read.parquet(paths : _*) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eb4be1900b153..e2c7e8006f3b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,19 +17,18 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, Strategy, execution} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} -import org.apache.spark.sql.parquet._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.types._ +import org.apache.spark.sql.{SQLContext, Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => @@ -306,57 +305,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object ParquetOperations extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // TODO: need to support writing to other types of files. Unify the below code paths. 
- case logical.WriteToFile(path, child) => - val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, sqlContext) - // Note: overwrite=false because otherwise the metadata we just created will be deleted - InsertIntoParquetTable(relation, planLater(child), overwrite = false) :: Nil - case logical.InsertIntoTable( - table: ParquetRelation, partition, child, overwrite, ifNotExists) => - InsertIntoParquetTable(table, planLater(child), overwrite) :: Nil - case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) => - val partitionColNames = relation.partitioningAttributes.map(_.name).toSet - val filtersToPush = filters.filter { pred => - val referencedColNames = pred.references.map(_.name).toSet - referencedColNames.intersect(partitionColNames).isEmpty - } - val prunePushedDownFilters = - if (sqlContext.conf.parquetFilterPushDown) { - (predicates: Seq[Expression]) => { - // Note: filters cannot be pushed down to Parquet if they contain more complex - // expressions than simple "Attribute cmp Literal" comparisons. Here we remove all - // filters that have been pushed down. Note that a predicate such as "(A AND B) OR C" - // can result in "A OR C" being pushed down. Here we are conservative in the sense - // that even if "A" was pushed and we check for "A AND B" we still want to keep - // "A AND B" in the higher-level filter, not just "B". - predicates.map(p => p -> ParquetFilters.createFilter(p)).collect { - case (predicate, None) => predicate - // Filter needs to be applied above when it contains partitioning - // columns - case (predicate, _) - if !predicate.references.map(_.name).toSet.intersect(partitionColNames).isEmpty => - predicate - } - } - } else { - identity[Seq[Expression]] _ - } - pruneFilterProject( - projectList, - filters, - prunePushedDownFilters, - ParquetTableScan( - _, - relation, - if (sqlContext.conf.parquetFilterPushDown) filtersToPush else Nil)) :: Nil - - case _ => Nil - } - } - object InMemoryScans extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala new file mode 100644 index 0000000000000..975fec101d9c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystReadSupport.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.parquet
+
+import java.util.{Map => JMap}
+
+import scala.collection.JavaConversions.{iterableAsScalaIterable, mapAsJavaMap, mapAsScalaMap}
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.hadoop.api.ReadSupport.ReadContext
+import org.apache.parquet.hadoop.api.{InitContext, ReadSupport}
+import org.apache.parquet.io.api.RecordMaterializer
+import org.apache.parquet.schema.MessageType
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+
+private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with Logging {
+  override def prepareForRead(
+      conf: Configuration,
+      keyValueMetaData: JMap[String, String],
+      fileSchema: MessageType,
+      readContext: ReadContext): RecordMaterializer[InternalRow] = {
+    log.debug(s"Preparing to read Parquet file with message type: $fileSchema")
+
+    val toCatalyst = new CatalystSchemaConverter(conf)
+    val parquetRequestedSchema = readContext.getRequestedSchema
+
+    val catalystRequestedSchema =
+      Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata =>
+        metadata
+          // First tries to read requested schema, which may result from projections
+          .get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
+          // If not available, tries to read Catalyst schema from file metadata.  It's only
+          // available if the target file is written by Spark SQL.
+          .orElse(metadata.get(CatalystReadSupport.SPARK_METADATA_KEY))
+      }.map(StructType.fromString).getOrElse {
+        logDebug("Catalyst schema not available, falling back to Parquet schema")
+        toCatalyst.convert(parquetRequestedSchema)
+      }
+
+    logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema")
+    new CatalystRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema)
+  }
+
+  override def init(context: InitContext): ReadContext = {
+    val conf = context.getConfiguration
+
+    // If the target file was written by Spark SQL, we should be able to find a serialized
+    // Catalyst schema of this file from its metadata.
+    val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA))
+
+    // Optional schema of requested columns, in the form of a string serialized from a Catalyst
+    // `StructType` containing all requested columns.
+    val maybeRequestedSchema = Option(conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA))
+
+    // Below we construct a Parquet schema containing all requested columns.  This schema tells
+    // Parquet which columns to read.
+    //
+    // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema.  Otherwise,
+    // we have to fall back to the full file schema which contains all columns in the file.
+    // Obviously this may waste IO bandwidth since it may read more columns than requested.
+    //
+    // Two things to note:
+    //
+    // 1. It's possible that some requested columns don't exist in the target Parquet file.  For
+    //    example, in the case of schema merging, the globally merged schema may contain extra
+    //    columns gathered from other Parquet files.  These columns will simply be filled with
+    //    nulls when actually reading the target Parquet file.
+    //
+    // 2. When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to
+    //    a Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due
+    //    to non-standard behaviors of some Parquet libraries/tools.  For example, a Parquet file
+    //    containing a single integer array field `f1` may have the following legacy 2-level
+    //    structure:
+    //
+    //      message root {
+    //        optional group f1 (LIST) {
+    //          required INT32 element;
+    //        }
+    //      }
+    //
+    //    while `CatalystSchemaConverter` may generate a standard 3-level structure:
+    //
+    //      message root {
+    //        optional group f1 (LIST) {
+    //          repeated group list {
+    //            required INT32 element;
+    //          }
+    //        }
+    //      }
+    //
+    //    Apparently, we can't use the 2nd schema to read the target Parquet file as they have
+    //    different physical structures.
+    val parquetRequestedSchema =
+      maybeRequestedSchema.fold(context.getFileSchema) { schemaString =>
+        val toParquet = new CatalystSchemaConverter(conf)
+        val fileSchema = context.getFileSchema.asGroupType()
+        val fileFieldNames = fileSchema.getFields.map(_.getName).toSet
+
+        StructType
+          // Deserializes the Catalyst schema of requested columns
+          .fromString(schemaString)
+          .map { field =>
+            if (fileFieldNames.contains(field.name)) {
+              // If the field exists in the target Parquet file, extracts the field type from the
+              // full file schema and makes a single-field Parquet schema
+              new MessageType("root", fileSchema.getType(field.name))
+            } else {
+              // Otherwise, just resorts to `CatalystSchemaConverter`
+              toParquet.convert(StructType(Array(field)))
+            }
+          }
+          // Merges all single-field Parquet schemas to form a complete schema for all requested
+          // columns.  Note that it's possible that no columns are requested at all (e.g., when
+          // counting some partition column of a partitioned Parquet table).  That's why `fold` is
+          // used here, always falling back to an empty Parquet schema.
+          .fold(new MessageType("root")) {
+            _ union _
+          }
+      }
+
+    val metadata =
+      Map.empty[String, String] ++
+        maybeRequestedSchema.map(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++
+        maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _)
+
+    logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema")
+    new ReadContext(parquetRequestedSchema, metadata)
+  }
+}
+
+private[parquet] object CatalystReadSupport {
+  val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema"
+
+  val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata"
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala
new file mode 100644
index 0000000000000..84f1dccfeb788
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRecordMaterializer.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package org.apache.spark.sql.parquet + +import org.apache.parquet.io.api.{GroupConverter, RecordMaterializer} +import org.apache.parquet.schema.MessageType + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +/** + * A [[RecordMaterializer]] for Catalyst rows. + * + * @param parquetSchema Parquet schema of the records to be read + * @param catalystSchema Catalyst schema of the rows to be constructed + */ +private[parquet] class CatalystRecordMaterializer( + parquetSchema: MessageType, catalystSchema: StructType) + extends RecordMaterializer[InternalRow] { + + private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) + + override def getCurrentRecord: InternalRow = rootConverter.currentRow + + override def getRootConverter: GroupConverter = rootConverter +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 1d3a0d15d336e..e9ef01e2dba1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -570,6 +570,11 @@ private[parquet] object CatalystSchemaConverter { """.stripMargin.split("\n").mkString(" ")) } + def checkFieldNames(schema: StructType): StructType = { + schema.fieldNames.foreach(checkFieldName) + schema + } + def analysisRequire(f: => Boolean, message: String): Unit = { if (!f) { throw new AnalysisException(message) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index be0a2029d233b..ea51650fe9039 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet import org.apache.spark.sql.catalyst.InternalRow +// TODO Removes this while fixing SPARK-8848 private[sql] object CatalystConverter { // This is mostly Parquet convention (see, e.g., `ConversionPatterns`). // Note that "array" for the array elements is chosen by ParquetAvro. 
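(Context for reviewers: `CatalystRecordMaterializer` implements parquet-mr's record-assembly
contract. parquet-mr asks the materializer for a root `GroupConverter`, calls `start()`/`end()`
around each record while pushing decoded primitive values into leaf `PrimitiveConverter`s, and
reads `getCurrentRecord` after `end()`. Below is a minimal, hypothetical sketch of that contract
— illustrative only, not Spark code — assuming a flat schema containing only INT32 and UTF8
binary columns; `FlatRowMaterializer` is a made-up name for this example.

import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter, RecordMaterializer}
import org.apache.parquet.schema.MessageType

class FlatRowMaterializer(schema: MessageType) extends RecordMaterializer[Array[Any]] {
  // One slot per top-level column; leaf converters write decoded values into it.
  private val row = new Array[Any](schema.getFieldCount)

  private val root: GroupConverter = new GroupConverter {
    // One primitive converter per column; parquet-mr calls addInt/addBinary during decoding.
    private val fields = Array.tabulate[Converter](schema.getFieldCount) { i =>
      new PrimitiveConverter {
        override def addInt(value: Int): Unit = row(i) = value
        override def addBinary(value: Binary): Unit = row(i) = value.toStringUsingUTF8
      }
    }

    override def getConverter(fieldIndex: Int): Converter = fields(fieldIndex)
    override def start(): Unit = (0 until row.length).foreach(i => row(i) = null) // new record
    override def end(): Unit = ()  // record fully assembled
  }

  // Called by parquet-mr after end(); must return the assembled record.
  override def getCurrentRecord: Array[Any] = row.clone()

  override def getRootConverter: GroupConverter = root
}

`CatalystRecordMaterializer` follows the same shape, with `CatalystRowConverter` playing the role
of the root converter and producing an `InternalRow` rather than an array.)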
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 086559e9f7658..cc6fa2b88663f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -17,81 +17,720 @@ package org.apache.spark.sql.parquet -import java.io.IOException +import java.net.URI import java.util.logging.{Level, Logger => JLogger} +import java.util.{List => JList} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.permission.FsAction +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.{Failure, Try} + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetRecordReader, _} import org.apache.parquet.schema.MessageType import org.apache.parquet.{Log => ParquetLog} -import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.RDD._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} + + +private[sql] class DefaultSource extends HadoopFsRelationProvider { + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + } +} + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. +private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) + extends OutputWriterInternal { + + private val recordWriter: RecordWriter[Void, InternalRow] = { + val outputFormat = { + new ParquetOutputFormat[InternalRow]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). + // + // 2. To allow dynamic partitioning. 
Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + } + } + + outputFormat.getRecordWriter(context) + } + + override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(context) +} + +private[sql] class ParquetRelation( + override val paths: Array[String], + private val maybeDataSchema: Option[StructType], + // This is for metastore conversion. + private val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) + with Logging { + + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + + // Should we merge schemas from all Parquet part-files? + private val shouldMergeSchemas = + parameters + .get(ParquetRelation.MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + + private val maybeMetastoreSchema = parameters + .get(ParquetRelation.METASTORE_SCHEMA) + .map(DataType.fromJson(_).asInstanceOf[StructType]) + + private lazy val metadataCache: MetadataCache = { + val meta = new MetadataCache + meta.refresh() + meta + } -/** - * Relation that consists of data stored in a Parquet columnar format. - * - * Users should interact with parquet files though a [[DataFrame]], created by a [[SQLContext]] - * instead of using this class directly. - * - * {{{ - * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") - * }}} - * - * @param path The path to the Parquet file. 
- */ -private[sql] case class ParquetRelation( - path: String, - @transient conf: Option[Configuration], - @transient sqlContext: SQLContext, - partitioningAttributes: Seq[Attribute] = Nil) - extends LeafNode with MultiInstanceRelation { - - /** Schema derived from ParquetFile */ - def parquetSchema: MessageType = - ParquetTypesConverter - .readMetaData(new Path(path), conf) - .getFileMetaData - .getSchema - - /** Attributes */ - override val output = - partitioningAttributes ++ - ParquetTypesConverter.readSchemaFromFile( - new Path(path.split(",").head), - conf, - sqlContext.conf.isParquetBinaryAsString, - sqlContext.conf.isParquetINT96AsTimestamp) - lazy val attributeMap = AttributeMap(output.map(o => o -> o)) - - override def newInstance(): this.type = { - ParquetRelation(path, conf, sqlContext).asInstanceOf[this.type] - } - - // Equals must also take into account the output attributes so that we can distinguish between - // different instances of the same relation, override def equals(other: Any): Boolean = other match { - case p: ParquetRelation => - p.path == path && p.output == output + case that: ParquetRelation => + val schemaEquality = if (shouldMergeSchemas) { + this.shouldMergeSchemas == that.shouldMergeSchemas + } else { + this.dataSchema == that.dataSchema && + this.schema == that.schema + } + + this.paths.toSet == that.paths.toSet && + schemaEquality && + this.maybeDataSchema == that.maybeDataSchema && + this.partitionColumns == that.partitionColumns + case _ => false } - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(path, output) + override def hashCode(): Int = { + if (shouldMergeSchemas) { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + maybeDataSchema, + partitionColumns) + } else { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + dataSchema, + schema, + maybeDataSchema, + partitionColumns) + } + } + + /** Constraints on schema of dataframe to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to parquet format") + } + } + + override def dataSchema: StructType = { + val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) + // check if schema satisfies the constraints + // before moving forward + checkConstraints(schema) + schema + } + + override private[sql] def refresh(): Unit = { + super.refresh() + metadataCache.refresh() + } + + // Parquet data source always uses Catalyst internal representations. 
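+  // (`needConversion = false` tells the data source framework that rows produced by `buildScan`
+  // are already Catalyst `InternalRow`s, so no external-to-internal row conversion is applied.)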
+  override val needConversion: Boolean = false
+
+  override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum
+
+  override def prepareJobForWrite(job: Job): OutputWriterFactory = {
+    val conf = ContextUtil.getConfiguration(job)
+
+    val committerClass =
+      conf.getClass(
+        SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key,
+        classOf[ParquetOutputCommitter],
+        classOf[ParquetOutputCommitter])
+
+    if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) {
+      logInfo("Using default output committer for Parquet: " +
+        classOf[ParquetOutputCommitter].getCanonicalName)
+    } else {
+      logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName)
+    }
+
+    conf.setClass(
+      SQLConf.OUTPUT_COMMITTER_CLASS.key,
+      committerClass,
+      classOf[ParquetOutputCommitter])
+
+    // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override
+    // it in `ParquetOutputWriter` to support appending and dynamic partitioning.  The reason why
+    // we set it here is to set up the output committer class to `ParquetOutputCommitter`, which is
+    // bundled with `ParquetOutputFormat[Row]`.
+    job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]])
+
+    // TODO There's no need to use two kinds of WriteSupport
+    // We should unify them.  `SpecificMutableRow` can process both atomic (primitive) types and
+    // complex types.
+    val writeSupportClass =
+      if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) {
+        classOf[MutableRowWriteSupport]
+      } else {
+        classOf[RowWriteSupport]
+      }
+
+    ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass)
+    RowWriteSupport.setSchema(dataSchema.toAttributes, conf)
+
+    // Sets compression scheme
+    conf.set(
+      ParquetOutputFormat.COMPRESSION,
+      ParquetRelation
+        .shortParquetCompressionCodecNames
+        .getOrElse(
+          sqlContext.conf.parquetCompressionCodec.toUpperCase,
+          CompressionCodecName.UNCOMPRESSED).name())
+
+    new OutputWriterFactory {
+      override def newInstance(
+          path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = {
+        new ParquetOutputWriter(path, context)
+      }
+    }
+  }
+
+  override def buildScan(
+      requiredColumns: Array[String],
+      filters: Array[Filter],
+      inputFiles: Array[FileStatus],
+      broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = {
+    val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA)
+    val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown
+    val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
+    val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
+    val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec
+
+    // Create the function to set variable Parquet confs at both the driver and executor side.
+    val initLocalJobFuncOpt =
+      ParquetRelation.initializeLocalJobFunc(
+        requiredColumns,
+        filters,
+        dataSchema,
+        useMetadataCache,
+        parquetFilterPushDown,
+        assumeBinaryIsString,
+        assumeInt96IsTimestamp,
+        followParquetFormatSpec) _
+
+    // Create the function to set input paths at the driver side. 
+ val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles) _ + + Utils.withDummyCallSite(sqlContext.sparkContext) { + new SqlNewHadoopRDD( + sc = sqlContext.sparkContext, + broadcastedConf = broadcastedConf, + initDriverSideJobFuncOpt = Some(setInputPaths), + initLocalJobFuncOpt = Some(initLocalJobFuncOpt), + inputFormatClass = classOf[ParquetInputFormat[InternalRow]], + keyClass = classOf[Void], + valueClass = classOf[InternalRow]) { + + val cacheMetadata = useMetadataCache + + @transient val cachedStatuses = inputFiles.map { f => + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) + new FileStatus( + f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) + }.toSeq + + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } + + // Overridden so we can inject our own cached files statuses. + override def getPartitions: Array[SparkPartition] = { + val inputFormat = new ParquetInputFormat[InternalRow] { + override def listStatus(jobContext: JobContext): JList[FileStatus] = { + if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) + } + } + + val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + } + }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + } } - // TODO: Use data from the footers. - override lazy val statistics = Statistics(sizeInBytes = sqlContext.conf.defaultSizeInBytes) + private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. + private var metadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all "_common_metadata" files. + private var commonMetadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all data files (Parquet part-files). + var dataStatuses: Array[FileStatus] = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var dataSchema: StructType = null + + // Schema of the whole table, including partition columns. + var schema: StructType = _ + + // Cached leaves + var cachedLeaves: Set[FileStatus] = null + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ + def refresh(): Unit = { + val currentLeafStatuses = cachedLeafStatuses() + + // Check if cachedLeafStatuses is changed or not + val leafStatusesChanged = (cachedLeaves == null) || + !cachedLeaves.equals(currentLeafStatuses) + + if (leafStatusesChanged) { + cachedLeaves = currentLeafStatuses.toIterator.toSet + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. 
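+        // (Summary files `_metadata` and `_common_metadata` are kept even though their names
+        // start with an underscore; all other files starting with "_" or "." are skipped.)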
+        val leaves = currentLeafStatuses.filter { f =>
+          isSummaryFile(f.getPath) ||
+            !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith("."))
+        }.toArray
+
+        dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath))
+        metadataStatuses =
+          leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)
+        commonMetadataStatuses =
+          leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)
+
+        dataSchema = {
+          val dataSchema0 = maybeDataSchema
+            .orElse(readSchema())
+            .orElse(maybeMetastoreSchema)
+            .getOrElse(throw new AnalysisException(
+              s"Failed to discover schema of Parquet file(s) in the following location(s):\n" +
+                paths.mkString("\n\t")))
+
+          // If this Parquet relation is converted from a Hive Metastore table, we must reconcile
+          // the case insensitivity issue and possible schema mismatches (probably caused by schema
+          // evolution).
+          maybeMetastoreSchema
+            .map(ParquetRelation.mergeMetastoreParquetSchema(_, dataSchema0))
+            .getOrElse(dataSchema0)
+        }
+      }
+    }
+
+    private def isSummaryFile(file: Path): Boolean = {
+      file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
+        file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
+    }
+
+    private def readSchema(): Option[StructType] = {
+      // Sees which file(s) we need to touch in order to figure out the schema.
+      //
+      // Always tries the summary files first if users don't require a merged schema.  In this
+      // case, "_common_metadata" is preferable to "_metadata" because it doesn't contain row
+      // groups information, and could be much smaller for large Parquet files with lots of row
+      // groups.  If no summary file is available, falls back to some random part-file.
+      //
+      // NOTE: Metadata stored in the summary files is merged from all part-files.  However, for
+      // user defined key-value metadata (in which we store the Spark SQL schema), Parquet doesn't
+      // know how to merge them correctly if some key is associated with different values in
+      // different part-files.  When this happens, Parquet simply gives up generating the summary
+      // file.  This implies that if a summary file is present, then:
+      //
+      //   1. Either all part-files have exactly the same Spark SQL schema, or
+      //   2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus
+      //      their schemas may differ from each other).
+      //
+      // Here we tend to be pessimistic and take the second case into account.  Basically this
+      // means we can't trust the summary files if users require a merged schema, and must touch
+      // all part-files to do the merge.
+      val filesToTouch =
+        if (shouldMergeSchemas) {
+          // Also includes summary files, 'cause there might be empty partition directories.
+          (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq
+        } else {
+          // Tries any "_common_metadata" first.  Parquet files written by old versions of Parquet
+          // don't have this.
+          commonMetadataStatuses.headOption
+            // Falls back to "_metadata"
+            .orElse(metadataStatuses.headOption)
+            // Summary file(s) not found, so the Parquet file is either corrupted, or different
+            // part-files contain conflicting user defined metadata (two or more values are
+            // associated with the same key in different files).  In either case, we fall back to
+            // the first part-file, and just assume all schemas are consistent.
+            .orElse(dataStatuses.headOption)
+            .toSeq
+        }
+
+      assert(
+        filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined,
+        "No predefined schema found, " +
+          s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.")
+
+      ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext)
+    }
+  }
 }
 
-private[sql] object ParquetRelation {
+private[sql] object ParquetRelation extends Logging {
+  // Whether we should merge schemas collected from all Parquet part-files.
+  private[sql] val MERGE_SCHEMA = "mergeSchema"
+
+  // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used
+  // internally.
+  private[sql] val METASTORE_SCHEMA = "metastoreSchema"
+
+  /** This closure sets various Parquet configurations at both driver side and executor side. */
+  private[parquet] def initializeLocalJobFunc(
+      requiredColumns: Array[String],
+      filters: Array[Filter],
+      dataSchema: StructType,
+      useMetadataCache: Boolean,
+      parquetFilterPushDown: Boolean,
+      assumeBinaryIsString: Boolean,
+      assumeInt96IsTimestamp: Boolean,
+      followParquetFormatSpec: Boolean)(job: Job): Unit = {
+    val conf = job.getConfiguration
+    conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName)
+
+    // Try to push down filters when filter push-down is enabled.
+    if (parquetFilterPushDown) {
+      filters
+        // Collects all converted Parquet filter predicates. Notice that not all predicates can be
+        // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
+        // is used here.
+        .flatMap(ParquetFilters.createFilter(dataSchema, _))
+        .reduceOption(FilterApi.and)
+        .foreach(ParquetInputFormat.setFilterPredicate(conf, _))
+    }
+
+    conf.set(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, {
+      val requestedSchema = StructType(requiredColumns.map(dataSchema(_)))
+      CatalystSchemaConverter.checkFieldNames(requestedSchema).json
+    })
+
+    conf.set(
+      RowWriteSupport.SPARK_ROW_SCHEMA,
+      CatalystSchemaConverter.checkFieldNames(dataSchema).json)
+
+    // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata
+    conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache)
+
+    // Sets flags for Parquet schema conversion
+    conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString)
+    conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp)
+    conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec)
+  }
+
+  /** This closure sets input paths at the driver side. */
+  private[parquet] def initializeDriverSideJobFunc(
+      inputFiles: Array[FileStatus])(job: Job): Unit = {
+    // We set the input paths at the driver side.
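
// A minimal, self-contained sketch of the predicate-combining pattern used in
// `initializeLocalJobFunc` above: filters that cannot be converted drop out via
// `flatMap`, and the survivors are AND-ed pairwise with `reduceOption`, so nothing is
// pushed down when no filter converts. `toPredOpt` stands in for
// `ParquetFilters.createFilter` and is purely illustrative.
def toPredOpt(f: String): Option[String] =
  if (f.contains("=")) Some(f) else None       // pretend only equality filters convert

val pushed = Seq("a = 1", "b LIKE 'x%'", "c = 2")
  .flatMap(toPredOpt)                          // "b LIKE 'x%'" is dropped
  .reduceOption((l, r) => s"($l) AND ($r)")    // None when the converted list is empty
// pushed == Some("(a = 1) AND (c = 2)")
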
+    logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}")
+    if (inputFiles.nonEmpty) {
+      FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*)
+    }
+  }
+
+  private[parquet] def readSchema(
+      footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = {
+
+    def parseParquetSchema(schema: MessageType): StructType = {
+      val converter = new CatalystSchemaConverter(
+        sqlContext.conf.isParquetBinaryAsString,
+        sqlContext.conf.isParquetINT96AsTimestamp,
+        sqlContext.conf.followParquetFormatSpec)
+
+      converter.convert(schema)
+    }
+
+    val seen = mutable.HashSet[String]()
+    val finalSchemas: Seq[StructType] = footers.flatMap { footer =>
+      val metadata = footer.getParquetMetadata.getFileMetaData
+      val serializedSchema = metadata
+        .getKeyValueMetaData
+        .toMap
+        .get(CatalystReadSupport.SPARK_METADATA_KEY)
+      if (serializedSchema.isEmpty) {
+        // Falls back to Parquet schema if no Spark SQL schema found.
+        Some(parseParquetSchema(metadata.getSchema))
+      } else if (!seen.contains(serializedSchema.get)) {
+        seen += serializedSchema.get
+
+        // Don't throw even if we failed to parse the serialized Spark schema. Just fall back to
+        // whatever is available.
+        Some(Try(DataType.fromJson(serializedSchema.get))
+          .recover { case _: Throwable =>
+            logInfo(
+              s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " +
+                "falling back to the deprecated DataType.fromCaseClassString parser.")
+            DataType.fromCaseClassString(serializedSchema.get)
+          }
+          .recover { case cause: Throwable =>
+            logWarning(
+              s"""Failed to parse serialized Spark schema in Parquet key-value metadata:
+                 |\t$serializedSchema
+               """.stripMargin,
+              cause)
+          }
+          .map(_.asInstanceOf[StructType])
+          .getOrElse {
+            // Falls back to Parquet schema if Spark SQL schema can't be parsed.
+            parseParquetSchema(metadata.getSchema)
+          })
+      } else {
+        None
+      }
+    }
+
+    finalSchemas.reduceOption { (left, right) =>
+      try left.merge(right) catch { case e: Throwable =>
+        throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e)
+      }
+    }
+  }
+
+  /**
+   * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore
+   * schema and Parquet schema.
+   *
+   * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the
+   * schema read from Parquet files may be incomplete (e.g. older versions of Parquet don't
+   * distinguish binary and string). This method generates a correct schema by merging Metastore
+   * schema data types and Parquet schema field names.
+   */
+  private[parquet] def mergeMetastoreParquetSchema(
+      metastoreSchema: StructType,
+      parquetSchema: StructType): StructType = {
+    def schemaConflictMessage: String =
+      s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema:
+         |${metastoreSchema.prettyJson}
+         |
+         |Parquet schema:
+         |${parquetSchema.prettyJson}
+       """.stripMargin
+
+    val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema)
+
+    assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage)
+
+    val ordinalMap = metastoreSchema.zipWithIndex.map {
+      case (field, index) => field.name.toLowerCase -> index
+    }.toMap
+
+    val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
+      ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1))
+
+    StructType(metastoreSchema.zip(reorderedParquetSchema).map {
+      // Uses Parquet field names but retains Metastore data types.
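
// Worked example of the ordinal-map reordering built above, with plain field names
// standing in for `StructField`s (all values here are illustrative):
val metastoreNames = Seq("id", "name", "ts")         // Metastore order, lower case
val parquetNames = Seq("Name", "TS", "ID", "extra")  // Parquet order, mixed case
val ordinal = metastoreNames.zipWithIndex.toMap
val reordered =
  parquetNames.sortBy(n => ordinal.getOrElse(n.toLowerCase, metastoreNames.size + 1))
// reordered == Seq("ID", "Name", "TS", "extra"): Metastore order, Parquet spelling,
// unmatched fields sorted last, ready to be zipped against the Metastore schema.
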
+      case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase =>
+        mSchema.copy(name = pSchema.name)
+      case _ =>
+        throw new SparkException(schemaConflictMessage)
+    })
+  }
+
+  /**
+   * Returns the original schema from the Parquet file with any missing nullable fields from the
+   * Hive Metastore schema merged in.
+   *
+   * When constructing a DataFrame from a collection of structured data, the resulting object has
+   * a schema corresponding to the union of the fields present in each element of the collection.
+   * Spark SQL simply assigns a null value to any field that isn't present for a particular row.
+   * In some cases, it is possible that a given table partition stored as a Parquet file doesn't
+   * contain a particular nullable field in its schema despite that field being present in the
+   * table schema obtained from the Hive Metastore. This method returns a schema representing the
+   * Parquet file schema along with any additional nullable fields from the Metastore schema
+   * merged in.
+   */
+  private[parquet] def mergeMissingNullableFields(
+      metastoreSchema: StructType,
+      parquetSchema: StructType): StructType = {
+    val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap
+    val missingFields = metastoreSchema
+      .map(_.name.toLowerCase)
+      .diff(parquetSchema.map(_.name.toLowerCase))
+      .map(fieldMap(_))
+      .filter(_.nullable)
+    StructType(parquetSchema ++ missingFields)
+  }
+
+  /**
+   * Figures out a merged Parquet schema with a distributed Spark job.
+   *
+   * Note that locality is not taken into consideration here because:
+   *
+   * 1. For a single Parquet part-file, in most cases the footer only resides in the last block of
+   *    that file. Thus we only need to retrieve the location of the last block. However, Hadoop
+   *    `FileSystem` only provides API to retrieve locations of all blocks, which can be
+   *    potentially expensive.
+   *
+   * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty
+   *    slow. And basically locality is not available when using S3 (you can't run computation on
+   *    S3 nodes).
+   */
+  def mergeSchemasInParallel(
+      filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = {
+    val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
+    val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
+    val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec
+    val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration)
+
+    // HACK ALERT:
+    //
+    // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es
+    // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable`
+    // but only `Writable`. What makes it worse, for some reason, `FileStatus` doesn't play well
+    // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These
+    // facts virtually prevent us from serializing `FileStatus`es.
+    //
+    // Since Parquet only relies on path and length information of those `FileStatus`es to read
+    // footers, here we just extract them (which can be easily serialized), send them to executor
+    // side, and reassemble fake `FileStatus`es there.
+    val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen))
+
+    // Issues a Spark job to read Parquet schema in parallel.
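
// A small sketch of the (path, length) round trip described in the HACK ALERT above:
// only these two fields matter for footer reading, so they are shipped as plain pairs
// and fake `FileStatus`es are rebuilt on the executor side. The constructor arguments
// mirror the ones used below; `shipAndRebuild` is an illustrative name.
import org.apache.hadoop.fs.{FileStatus, Path}

def shipAndRebuild(statuses: Seq[FileStatus]): Seq[FileStatus] = {
  val partial = statuses.map(f => (f.getPath.toString, f.getLen))  // serializable pairs
  partial.map { case (path, length) =>
    // isdir = false; replication, block size, and timestamps zeroed; no owner info
    new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path))
  }
}
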
+ val partiallyMergedSchemas = + sqlContext + .sparkContext + .parallelize(partialFileStatusInfo) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + // Skips row group information since we only need the schema + val skipRowGroups = true + + // Reads footers in multi-threaded manner within each task + val footers = + ParquetFileReader.readAllFootersInParallel( + serializedConf.value, fakeFileStatuses, skipRowGroups) + + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val converter = + new CatalystSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + footers.map { footer => + ParquetRelation.readSchemaFromFooter(footer, converter) + }.reduceOption(_ merge _).iterator + }.collect() + + partiallyMergedSchemas.reduceOption(_ merge _) + } + + /** + * Reads Spark SQL schema from a Parquet footer. If a valid serialized Spark SQL schema string + * can be found in the file metadata, returns the deserialized [[StructType]], otherwise, returns + * a [[StructType]] converted from the [[MessageType]] stored in this footer. + */ + def readSchemaFromFooter( + footer: Footer, converter: CatalystSchemaConverter): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData + .getKeyValueMetaData + .toMap + .get(CatalystReadSupport.SPARK_METADATA_KEY) + .flatMap(deserializeSchemaString) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString(schemaString: String): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). + Try(DataType.fromJson(schemaString).asInstanceOf[StructType]).recover { + case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + }.recoverWith { + case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", cause) + Failure(cause) + }.toOption + } def enableLogForwarding() { // Note: the org.apache.parquet.Log class has a static initializer that @@ -127,12 +766,6 @@ private[sql] object ParquetRelation { JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) } - // The element type for the RDDs that this relation maps to. - type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow - - // The compression type - type CompressionType = org.apache.parquet.hadoop.metadata.CompressionCodecName - // The parquet compression short names val shortParquetCompressionCodecNames = Map( "NONE" -> CompressionCodecName.UNCOMPRESSED, @@ -140,82 +773,4 @@ private[sql] object ParquetRelation { "SNAPPY" -> CompressionCodecName.SNAPPY, "GZIP" -> CompressionCodecName.GZIP, "LZO" -> CompressionCodecName.LZO) - - /** - * Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. 
Note that - * this is used inside [[org.apache.spark.sql.execution.SparkStrategies SparkStrategies]] to - * create a resolved relation as a data sink for writing to a Parquetfile. The relation is empty - * but is initialized with ParquetMetadata and can be inserted into. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param child The child node that will be used for extracting the schema. - * @param conf A configuration to be used. - * @return An empty ParquetRelation with inferred metadata. - */ - def create(pathString: String, - child: LogicalPlan, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - if (!child.resolved) { - throw new UnresolvedException[LogicalPlan]( - child, - "Attempt to create Parquet table from unresolved child (when schema is not available)") - } - createEmpty(pathString, child.output, false, conf, sqlContext) - } - - /** - * Creates an empty ParquetRelation and underlying Parquetfile that only - * consists of the Metadata for the given schema. - * - * @param pathString The directory the Parquetfile will be stored in. - * @param attributes The schema of the relation. - * @param conf A configuration to be used. - * @return An empty ParquetRelation. - */ - def createEmpty(pathString: String, - attributes: Seq[Attribute], - allowExisting: Boolean, - conf: Configuration, - sqlContext: SQLContext): ParquetRelation = { - val path = checkPath(pathString, allowExisting, conf) - conf.set(ParquetOutputFormat.COMPRESSION, shortParquetCompressionCodecNames.getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED) - .name()) - ParquetRelation.enableLogForwarding() - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val schema = StructType.fromAttributes(attributes).asNullable - val newAttributes = schema.toAttributes - ParquetTypesConverter.writeMetaData(newAttributes, path, conf) - new ParquetRelation(path.toString, Some(conf), sqlContext) { - override val output = newAttributes - } - } - - private def checkPath(pathStr: String, allowExisting: Boolean, conf: Configuration): Path = { - if (pathStr == null) { - throw new IllegalArgumentException("Unable to create ParquetRelation: path is null") - } - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"Unable to create ParquetRelation: incorrectly formatted path $pathStr") - } - val path = origPath.makeQualified(fs) - if (!allowExisting && fs.exists(path)) { - sys.error(s"File $pathStr already exists.") - } - - if (fs.exists(path) && - !fs.getFileStatus(path) - .getPermission - .getUserAction - .implies(FsAction.READ_WRITE)) { - throw new IOException( - s"Unable to create ParquetRelation: path $path not read-writable") - } - path - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala deleted file mode 100644 index 75cbbde4f1512..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ /dev/null @@ -1,492 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.parquet - -import java.io.IOException -import java.text.{NumberFormat, SimpleDateFormat} -import java.util.concurrent.TimeUnit -import java.util.Date - -import scala.collection.JavaConversions._ -import scala.util.Try - -import com.google.common.cache.CacheBuilder -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{BlockLocation, FileStatus, Path} -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat => NewFileOutputFormat} -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.api.ReadSupport -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.schema.MessageType - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SQLConf -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, _} -import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.sql.types.StructType -import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.util.SerializableConfiguration - -/** - * :: DeveloperApi :: - * Parquet table scan operator. Imports the file that backs the given - * [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[InternalRow]``. - */ -private[sql] case class ParquetTableScan( - attributes: Seq[Attribute], - relation: ParquetRelation, - columnPruningPred: Seq[Expression]) - extends LeafNode { - - // The resolution of Parquet attributes is case sensitive, so we resolve the original attributes - // by exprId. note: output cannot be transient, see - // https://issues.apache.org/jira/browse/SPARK-1367 - val output = attributes.map(relation.attributeMap) - - // A mapping of ordinals partitionRow -> finalOutput. 
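
// Distilled version of the mapping built below: for every output attribute that is a
// partition column, pair its ordinal in the partition row with its ordinal in the
// final output. Plain strings stand in for `Attribute`s; the values are made up.
val partitionCols = Seq("year", "month")              // plays partitioningAttributes
val outputCols = Seq("id", "year", "value", "month")  // plays attributes
val partOrdinals = partitionCols.zipWithIndex.toMap
val requested = outputCols.zipWithIndex.flatMap { case (a, finalOrdinal) =>
  partOrdinals.get(a).map(_ -> finalOrdinal)
}.toArray
// requested.toSeq == Seq(0 -> 1, 1 -> 3): "year" fills output slot 1, "month" slot 3
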
- val requestedPartitionOrdinals = { - val partitionAttributeOrdinals = AttributeMap(relation.partitioningAttributes.zipWithIndex) - - attributes.zipWithIndex.flatMap { - case (attribute, finalOrdinal) => - partitionAttributeOrdinals.get(attribute).map(_ -> finalOrdinal) - } - }.toArray - - protected override def doExecute(): RDD[InternalRow] = { - import org.apache.parquet.filter2.compat.FilterCompat.FilterPredicateCompat - - val sc = sqlContext.sparkContext - val job = new Job(sc.hadoopConfiguration) - ParquetInputFormat.setReadSupportClass(job, classOf[RowReadSupport]) - - val conf: Configuration = ContextUtil.getConfiguration(job) - - relation.path.split(",").foreach { curPath => - val qualifiedPath = { - val path = new Path(curPath) - path.getFileSystem(conf).makeQualified(path) - } - NewFileInputFormat.addInputPath(job, qualifiedPath) - } - - // Store both requested and original schema in `Configuration` - conf.set( - RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - ParquetTypesConverter.convertToString(output)) - conf.set( - RowWriteSupport.SPARK_ROW_SCHEMA, - ParquetTypesConverter.convertToString(relation.output)) - - // Store record filtering predicate in `Configuration` - // Note 1: the input format ignores all predicates that cannot be expressed - // as simple column predicate filters in Parquet. Here we just record - // the whole pruning predicate. - ParquetFilters - .createRecordFilter(columnPruningPred) - .map(_.asInstanceOf[FilterPredicateCompat].getFilterPredicate) - // Set this in configuration of ParquetInputFormat, needed for RowGroupFiltering - .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) - - // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata - conf.setBoolean( - SQLConf.PARQUET_CACHE_METADATA.key, - sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA, true)) - - // Use task side metadata in parquet - conf.setBoolean(ParquetInputFormat.TASK_SIDE_METADATA, true) - - val baseRDD = - new org.apache.spark.rdd.NewHadoopRDD( - sc, - classOf[FilteringParquetRowInputFormat], - classOf[Void], - classOf[InternalRow], - conf) - - if (requestedPartitionOrdinals.nonEmpty) { - // This check is based on CatalystConverter.createRootConverter. - val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType)) - - // Uses temporary variable to avoid the whole `ParquetTableScan` object being captured into - // the `mapPartitionsWithInputSplit` closure below. - val outputSize = output.size - - baseRDD.mapPartitionsWithInputSplit { case (split, iter) => - val partValue = "([^=]+)=([^=]+)".r - val partValues = - split.asInstanceOf[org.apache.parquet.hadoop.ParquetInputSplit] - .getPath - .toString - .split("/") - .flatMap { - case partValue(key, value) => Some(key -> value) - case _ => None - }.toMap - - // Convert the partitioning attributes into the correct types - val partitionRowValues = - relation.partitioningAttributes - .map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) - - if (primitiveRow) { - new Iterator[InternalRow] { - def hasNext: Boolean = iter.hasNext - def next(): InternalRow = { - // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow. - val row = iter.next()._2.asInstanceOf[SpecificMutableRow] - - // Parquet will leave partitioning columns empty, so we fill them in here. 
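
// Worked example of the "key=value" path-segment extraction a few lines above, using
// the same regex on a hypothetical split path (not taken from the patch):
val partValue = "([^=]+)=([^=]+)".r
val splitPath = "hdfs://nn/warehouse/t/year=2015/month=07/part-r-00001.parquet"
val partValues = splitPath.split("/").flatMap {
  case partValue(key, value) => Some(key -> value)
  case _ => None
}.toMap
// partValues == Map("year" -> "2015", "month" -> "07")
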
- var i = 0 - while (i < requestedPartitionOrdinals.length) { - row(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - row - } - } - } else { - // Create a mutable row since we need to fill in values from partition columns. - val mutableRow = new GenericMutableRow(outputSize) - new Iterator[InternalRow] { - def hasNext: Boolean = iter.hasNext - def next(): InternalRow = { - // We are using CatalystGroupConverter and it returns a GenericRow. - // Since GenericRow is not mutable, we just cast it to a Row. - val row = iter.next()._2.asInstanceOf[InternalRow] - - var i = 0 - while (i < row.numFields) { - mutableRow(i) = row.genericGet(i) - i += 1 - } - // Parquet will leave partitioning columns empty, so we fill them in here. - i = 0 - while (i < requestedPartitionOrdinals.length) { - mutableRow(requestedPartitionOrdinals(i)._2) = - partitionRowValues(requestedPartitionOrdinals(i)._1) - i += 1 - } - mutableRow - } - } - } - } - } else { - baseRDD.map(_._2) - } - } - - /** - * Applies a (candidate) projection. - * - * @param prunedAttributes The list of attributes to be used in the projection. - * @return Pruned TableScan. - */ - def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = { - val success = validateProjection(prunedAttributes) - if (success) { - ParquetTableScan(prunedAttributes, relation, columnPruningPred) - } else { - sys.error("Warning: Could not validate Parquet schema projection in pruneColumns") - } - } - - /** - * Evaluates a candidate projection by checking whether the candidate is a subtype - * of the original type. - * - * @param projection The candidate projection. - * @return True if the projection is valid, false otherwise. - */ - private def validateProjection(projection: Seq[Attribute]): Boolean = { - val original: MessageType = relation.parquetSchema - val candidate: MessageType = ParquetTypesConverter.convertFromAttributes(projection) - Try(original.checkContains(candidate)).isSuccess - } -} - -/** - * :: DeveloperApi :: - * Operator that acts as a sink for queries on RDDs and can be used to - * store the output inside a directory of Parquet files. This operator - * is similar to Hive's INSERT INTO TABLE operation in the sense that - * one can choose to either overwrite or append to a directory. Note - * that consecutive insertions to the same table must have compatible - * (source) schemas. - * - * WARNING: EXPERIMENTAL! InsertIntoParquetTable with overwrite=false may - * cause data corruption in the case that multiple users try to append to - * the same table simultaneously. Inserting into a table that was - * previously generated by other means (e.g., by creating an HDFS - * directory and importing Parquet files generated by other tools) may - * cause unpredicted behaviour and therefore results in a RuntimeException - * (only detected via filename pattern so will not catch all cases). - */ -@DeveloperApi -private[sql] case class InsertIntoParquetTable( - relation: ParquetRelation, - child: SparkPlan, - overwrite: Boolean = false) - extends UnaryNode with SparkHadoopMapReduceUtil { - - /** - * Inserts all rows into the Parquet file. - */ - protected override def doExecute(): RDD[InternalRow] = { - // TODO: currently we do not check whether the "schema"s are compatible - // That means if one first creates a table and then INSERTs data with - // and incompatible schema the execution will fail. 
It would be nice - // to catch this early one, maybe having the planner validate the schema - // before calling execute(). - - val childRdd = child.execute() - assert(childRdd != null) - - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - - val writeSupport = - if (child.output.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - log.debug("Initializing MutableRowWriteSupport") - classOf[org.apache.spark.sql.parquet.MutableRowWriteSupport] - } else { - classOf[org.apache.spark.sql.parquet.RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupport) - - val conf = ContextUtil.getConfiguration(job) - // This is a hack. We always set nullable/containsNull/valueContainsNull to true - // for the schema of a parquet data. - val schema = StructType.fromAttributes(relation.output).asNullable - RowWriteSupport.setSchema(schema.toAttributes, conf) - - val fspath = new Path(relation.path) - val fs = fspath.getFileSystem(conf) - - if (overwrite) { - try { - fs.delete(fspath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${fspath.toString} prior" - + s" to InsertIntoParquetTable:\n${e.toString}") - } - } - saveAsHadoopFile(childRdd, relation.path.toString, conf) - - // We return the child RDD to allow chaining (alternatively, one could return nothing). - childRdd - } - - override def output: Seq[Attribute] = child.output - - /** - * Stores the given Row RDD as a Hadoop file. - * - * Note: We cannot use ``saveAsNewAPIHadoopFile`` from [[org.apache.spark.rdd.PairRDDFunctions]] - * together with [[org.apache.spark.util.MutablePair]] because ``PairRDDFunctions`` uses - * ``Tuple2`` and not ``Product2``. Also, we want to allow appending files to an existing - * directory and need to determine which was the largest written file index before starting to - * write. - * - * @param rdd The [[org.apache.spark.rdd.RDD]] to writer - * @param path The directory to write to. - * @param conf A [[org.apache.hadoop.conf.Configuration]]. 
- */ - private def saveAsHadoopFile( - rdd: RDD[InternalRow], - path: String, - conf: Configuration) { - val job = new Job(conf) - val keyType = classOf[Void] - job.setOutputKeyClass(keyType) - job.setOutputValueClass(classOf[InternalRow]) - NewFileOutputFormat.setOutputPath(job, new Path(path)) - val wrappedConf = new SerializableConfiguration(job.getConfiguration) - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - val jobtrackerID = formatter.format(new Date()) - val stageId = sqlContext.sparkContext.newRddId() - - val taskIdOffset = - if (overwrite) { - 1 - } else { - FileSystemHelper - .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 - } - - def writeShard(context: TaskContext, iter: Iterator[InternalRow]): Int = { - /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, - context.attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = new AppendingParquetOutputFormat(taskIdOffset) - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - val writer = format.getRecordWriter(hadoopContext) - try { - while (iter.hasNext) { - val row = iter.next() - writer.write(null, row) - } - } finally { - writer.close(hadoopContext) - } - SparkHadoopMapRedUtil.commitTask(committer, hadoopContext, context) - 1 - } - val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) - /* apparently we need a TaskAttemptID to construct an OutputCommitter; - * however we're only going to use this local OutputCommitter for - * setupJob/commitJob, so we just use a dummy "map" task. - */ - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - jobCommitter.setupJob(jobTaskContext) - sqlContext.sparkContext.runJob(rdd, writeShard _) - jobCommitter.commitJob(jobTaskContext) - } -} - -/** - * TODO: this will be able to append to directories it created itself, not necessarily - * to imported ones. - */ -private[parquet] class AppendingParquetOutputFormat(offset: Int) - extends org.apache.parquet.hadoop.ParquetOutputFormat[InternalRow] { - // override to accept existing directories as valid output directory - override def checkOutputSpecs(job: JobContext): Unit = {} - var committer: OutputCommitter = null - - // override to choose output filename so not overwrite existing ones - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val taskId: TaskID = getTaskAttemptID(context).getTaskID - val partition: Int = taskId.getId - val filename = "part-r-" + numfmt.format(partition + offset) + ".parquet" - val committer: FileOutputCommitter = - getOutputCommitter(context).asInstanceOf[FileOutputCommitter] - new Path(committer.getWorkPath, filename) - } - - // The TaskAttemptContext is a class in hadoop-1 but is an interface in hadoop-2. - // The signatures of the method TaskAttemptContext.getTaskAttemptID for the both versions - // are the same, so the method calls are source-compatible but NOT binary-compatible because - // the opcode of method call for class is INVOKEVIRTUAL and for interface is INVOKEINTERFACE. 
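
// Worked example of the append-safe numbering in `getDefaultWorkFile` above: the
// partition id is shifted by the task-id offset and zero-padded to five digits. The
// concrete numbers are illustrative.
import java.text.NumberFormat

val numfmt = NumberFormat.getInstance()
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
val partition = 3
val offset = 12  // e.g. findMaxTaskId(...) + 1 when appending
val filename = "part-r-" + numfmt.format(partition + offset) + ".parquet"
// filename == "part-r-00015.parquet"
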
- private def getTaskAttemptID(context: TaskAttemptContext): TaskAttemptID = { - context.getClass.getMethod("getTaskAttemptID").invoke(context).asInstanceOf[TaskAttemptID] - } - - // override to create output committer from configuration - override def getOutputCommitter(context: TaskAttemptContext): OutputCommitter = { - if (committer == null) { - val output = getOutputPath(context) - val cls = context.getConfiguration.getClass("spark.sql.parquet.output.committer.class", - classOf[ParquetOutputCommitter], classOf[ParquetOutputCommitter]) - val ctor = cls.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(output, context).asInstanceOf[ParquetOutputCommitter] - } - committer - } - - // FileOutputFormat.getOutputPath takes JobConf in hadoop-1 but JobContext in hadoop-2 - private def getOutputPath(context: TaskAttemptContext): Path = { - context.getConfiguration().get("mapred.output.dir") match { - case null => null - case name => new Path(name) - } - } -} - -// TODO Removes this class after removing old Parquet support code -/** - * We extend ParquetInputFormat in order to have more control over which - * RecordFilter we want to use. - */ -private[parquet] class FilteringParquetRowInputFormat - extends org.apache.parquet.hadoop.ParquetInputFormat[InternalRow] with Logging { - - override def createRecordReader( - inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { - - import org.apache.parquet.filter2.compat.FilterCompat.NoOpFilter - - val readSupport: ReadSupport[InternalRow] = new RowReadSupport() - - val filter = ParquetInputFormat.getFilter(ContextUtil.getConfiguration(taskAttemptContext)) - if (!filter.isInstanceOf[NoOpFilter]) { - new ParquetRecordReader[InternalRow]( - readSupport, - filter) - } else { - new ParquetRecordReader[InternalRow](readSupport) - } - } - -} - -private[parquet] object FileSystemHelper { - def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { - val origPath = new Path(pathStr) - val fs = origPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException( - s"ParquetTableOperations: Path $origPath is incorrectly formatted") - } - val path = origPath.makeQualified(fs) - if (!fs.exists(path) || !fs.getFileStatus(path).isDir) { - throw new IllegalArgumentException( - s"ParquetTableOperations: path $path does not exist or is not a directory") - } - fs.globStatus(path) - .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } - .map(_.getPath) - } - - /** - * Finds the maximum taskid in the output file names at the given path. 
- */ - def findMaxTaskId(pathStr: String, conf: Configuration): Int = { - val files = FileSystemHelper.listFiles(pathStr, conf) - // filename pattern is part-r-.parquet - val nameP = new scala.util.matching.Regex("""part-.-(\d{1,}).*""", "taskid") - val hiddenFileP = new scala.util.matching.Regex("_.*") - files.map(_.getName).map { - case nameP(taskid) => taskid.toInt - case hiddenFileP() => 0 - case other: String => - sys.error("ERROR: attempting to append to set of Parquet files and found file" + - s"that does not match name pattern: $other") - case _ => 0 - }.reduceOption(_ max _).getOrElse(0) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 7b6a7f65d69db..fc9f61a636768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -18,18 +18,13 @@ package org.apache.spark.sql.parquet import java.nio.{ByteBuffer, ByteOrder} -import java.util import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions._ - import org.apache.hadoop.conf.Configuration import org.apache.parquet.column.ParquetProperties import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext -import org.apache.parquet.hadoop.api.{InitContext, ReadSupport, WriteSupport} +import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.io.api._ -import org.apache.parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -38,147 +33,6 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** - * A [[RecordMaterializer]] for Catalyst rows. - * - * @param parquetSchema Parquet schema of the records to be read - * @param catalystSchema Catalyst schema of the rows to be constructed - */ -private[parquet] class RowRecordMaterializer(parquetSchema: MessageType, catalystSchema: StructType) - extends RecordMaterializer[InternalRow] { - - private val rootConverter = new CatalystRowConverter(parquetSchema, catalystSchema, NoopUpdater) - - override def getCurrentRecord: InternalRow = rootConverter.currentRow - - override def getRootConverter: GroupConverter = rootConverter -} - -private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logging { - override def prepareForRead( - conf: Configuration, - keyValueMetaData: util.Map[String, String], - fileSchema: MessageType, - readContext: ReadContext): RecordMaterializer[InternalRow] = { - log.debug(s"Preparing for read Parquet file with message type: $fileSchema") - - val toCatalyst = new CatalystSchemaConverter(conf) - val parquetRequestedSchema = readContext.getRequestedSchema - - val catalystRequestedSchema = - Option(readContext.getReadSupportMetadata).map(_.toMap).flatMap { metadata => - metadata - // First tries to read requested schema, which may result from projections - .get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) - // If not available, tries to read Catalyst schema from file metadata. It's only - // available if the target file is written by Spark SQL. 
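
// Worked example of the task-id scan in `findMaxTaskId` above, against a made-up
// directory listing (the error branch for malformed names is omitted here):
val nameP = new scala.util.matching.Regex("""part-.-(\d{1,}).*""", "taskid")
val hiddenFileP = new scala.util.matching.Regex("_.*")
val maxTaskId = Seq("part-r-00007.parquet", "_SUCCESS", "part-r-00012.parquet").map {
  case nameP(taskid) => taskid.toInt
  case hiddenFileP() => 0
  case _ => 0
}.reduceOption(_ max _).getOrElse(0)
// maxTaskId == 12, so the next append would write task ids starting at 13
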
- .orElse(metadata.get(RowReadSupport.SPARK_METADATA_KEY)) - }.map(StructType.fromString).getOrElse { - logDebug("Catalyst schema not available, falling back to Parquet schema") - toCatalyst.convert(parquetRequestedSchema) - } - - logDebug(s"Catalyst schema used to read Parquet files: $catalystRequestedSchema") - new RowRecordMaterializer(parquetRequestedSchema, catalystRequestedSchema) - } - - override def init(context: InitContext): ReadContext = { - val conf = context.getConfiguration - - // If the target file was written by Spark SQL, we should be able to find a serialized Catalyst - // schema of this file from its the metadata. - val maybeRowSchema = Option(conf.get(RowWriteSupport.SPARK_ROW_SCHEMA)) - - // Optional schema of requested columns, in the form of a string serialized from a Catalyst - // `StructType` containing all requested columns. - val maybeRequestedSchema = Option(conf.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)) - - // Below we construct a Parquet schema containing all requested columns. This schema tells - // Parquet which columns to read. - // - // If `maybeRequestedSchema` is defined, we assemble an equivalent Parquet schema. Otherwise, - // we have to fallback to the full file schema which contains all columns in the file. - // Obviously this may waste IO bandwidth since it may read more columns than requested. - // - // Two things to note: - // - // 1. It's possible that some requested columns don't exist in the target Parquet file. For - // example, in the case of schema merging, the globally merged schema may contain extra - // columns gathered from other Parquet files. These columns will be simply filled with nulls - // when actually reading the target Parquet file. - // - // 2. When `maybeRequestedSchema` is available, we can't simply convert the Catalyst schema to - // Parquet schema using `CatalystSchemaConverter`, because the mapping is not unique due to - // non-standard behaviors of some Parquet libraries/tools. For example, a Parquet file - // containing a single integer array field `f1` may have the following legacy 2-level - // structure: - // - // message root { - // optional group f1 (LIST) { - // required INT32 element; - // } - // } - // - // while `CatalystSchemaConverter` may generate a standard 3-level structure: - // - // message root { - // optional group f1 (LIST) { - // repeated group list { - // required INT32 element; - // } - // } - // } - // - // Apparently, we can't use the 2nd schema to read the target Parquet file as they have - // different physical structures. - val parquetRequestedSchema = - maybeRequestedSchema.fold(context.getFileSchema) { schemaString => - val toParquet = new CatalystSchemaConverter(conf) - val fileSchema = context.getFileSchema.asGroupType() - val fileFieldNames = fileSchema.getFields.map(_.getName).toSet - - StructType - // Deserializes the Catalyst schema of requested columns - .fromString(schemaString) - .map { field => - if (fileFieldNames.contains(field.name)) { - // If the field exists in the target Parquet file, extracts the field type from the - // full file schema and makes a single-field Parquet schema - new MessageType("root", fileSchema.getType(field.name)) - } else { - // Otherwise, just resorts to `CatalystSchemaConverter` - toParquet.convert(StructType(Array(field))) - } - } - // Merges all single-field Parquet schemas to form a complete schema for all requested - // columns. 
Note that it's possible that no columns are requested at all (e.g., count - // some partition column of a partitioned Parquet table). That's why `fold` is used here - // and always fallback to an empty Parquet schema. - .fold(new MessageType("root")) { - _ union _ - } - } - - val metadata = - Map.empty[String, String] ++ - maybeRequestedSchema.map(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA -> _) ++ - maybeRowSchema.map(RowWriteSupport.SPARK_ROW_SCHEMA -> _) - - logInfo(s"Going to read Parquet file with these requested columns: $parquetRequestedSchema") - new ReadContext(parquetRequestedSchema, metadata) - } -} - -private[parquet] object RowReadSupport { - val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" - val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata" - - private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = { - val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA) - if (schemaString == null) null else ParquetTypesConverter.convertFromString(schemaString) - } -} - /** * A `parquet.hadoop.api.WriteSupport` for Row objects. */ @@ -190,7 +44,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo override def init(configuration: Configuration): WriteSupport.WriteContext = { val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) val metadata = new JHashMap[String, String]() - metadata.put(RowReadSupport.SPARK_METADATA_KEY, origAttributesStr) + metadata.put(CatalystReadSupport.SPARK_METADATA_KEY, origAttributesStr) if (attributes == null) { attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray @@ -443,4 +297,3 @@ private[parquet] object RowWriteSupport { ParquetProperties.WriterVersion.PARQUET_1_0.toString) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e748bd7857bd8..3854f5bd39fb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -53,15 +53,6 @@ private[parquet] object ParquetTypesConverter extends Logging { length } - def convertToAttributes( - parquetSchema: MessageType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - val converter = new CatalystSchemaConverter( - isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) - converter.convert(parquetSchema).toAttributes - } - def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { val converter = new CatalystSchemaConverter() converter.convert(StructType.fromAttributes(attributes)) @@ -103,7 +94,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } val extraMetadata = new java.util.HashMap[String, String]() extraMetadata.put( - RowReadSupport.SPARK_METADATA_KEY, + CatalystReadSupport.SPARK_METADATA_KEY, ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? @@ -165,35 +156,4 @@ private[parquet] object ParquetTypesConverter extends Logging { .getOrElse( throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) } - - /** - * Reads in Parquet Metadata from the given path and tries to extract the schema - * (Catalyst attributes) from the application-specific key-value map. 
If this - * is empty it falls back to converting from the Parquet file schema which - * may lead to an upcast of types (e.g., {byte, short} to int). - * - * @param origPath The path at which we expect one (or more) Parquet files. - * @param conf The Hadoop configuration to use. - * @return A list of attributes that make up the schema. - */ - def readSchemaFromFile( - origPath: Path, - conf: Option[Configuration], - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - val keyValueMetadata: java.util.Map[String, String] = - readMetaData(origPath, conf) - .getFileMetaData - .getKeyValueMetaData - if (keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY) != null) { - convertFromString(keyValueMetadata.get(RowReadSupport.SPARK_METADATA_KEY)) - } else { - val attributes = convertToAttributes( - readMetaData(origPath, conf).getFileMetaData.getSchema, - isBinaryAsString, - isInt96AsTimestamp) - log.info(s"Falling back to schema conversion from Parquet types; result: $attributes") - attributes - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala deleted file mode 100644 index 8ec228c2b25bc..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ /dev/null @@ -1,732 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.parquet - -import java.net.URI -import java.util.{List => JList} - -import scala.collection.JavaConversions._ -import scala.collection.mutable -import scala.util.{Failure, Try} - -import com.google.common.base.Objects -import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat -import org.apache.parquet.filter2.predicate.FilterApi -import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.util.ContextUtil -import org.apache.parquet.schema.MessageType - -import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDD._ -import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} -import org.apache.spark.sql.execution.datasources.PartitionSpec -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} - - -private[sql] class DefaultSource extends HadoopFsRelationProvider { - override def createRelation( - sqlContext: SQLContext, - paths: Array[String], - schema: Option[StructType], - partitionColumns: Option[StructType], - parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation2(paths, schema, None, partitionColumns, parameters)(sqlContext) - } -} - -// NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) - extends OutputWriterInternal { - - private val recordWriter: RecordWriter[Void, InternalRow] = { - val outputFormat = { - new ParquetOutputFormat[InternalRow]() { - // Here we override `getDefaultWorkFile` for two reasons: - // - // 1. To allow appending. We need to generate unique output file names to avoid - // overwriting existing files (either exist before the write job, or are just written - // by other tasks within the same write job). - // - // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses - // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all - // partitions in the case of dynamic partitioning. - override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") - val split = context.getTaskAttemptID.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") - } - } - } - - outputFormat.getRecordWriter(context) - } - - override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) - - override def close(): Unit = recordWriter.close(context) -} - -private[sql] class ParquetRelation2( - override val paths: Array[String], - private val maybeDataSchema: Option[StructType], - // This is for metastore conversion. 
- private val maybePartitionSpec: Option[PartitionSpec], - override val userDefinedPartitionColumns: Option[StructType], - parameters: Map[String, String])( - val sqlContext: SQLContext) - extends HadoopFsRelation(maybePartitionSpec) - with Logging { - - private[sql] def this( - paths: Array[String], - maybeDataSchema: Option[StructType], - maybePartitionSpec: Option[PartitionSpec], - parameters: Map[String, String])( - sqlContext: SQLContext) = { - this( - paths, - maybeDataSchema, - maybePartitionSpec, - maybePartitionSpec.map(_.partitionColumns), - parameters)(sqlContext) - } - - // Should we merge schemas from all Parquet part-files? - private val shouldMergeSchemas = - parameters - .get(ParquetRelation2.MERGE_SCHEMA) - .map(_.toBoolean) - .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) - - private val maybeMetastoreSchema = parameters - .get(ParquetRelation2.METASTORE_SCHEMA) - .map(DataType.fromJson(_).asInstanceOf[StructType]) - - private lazy val metadataCache: MetadataCache = { - val meta = new MetadataCache - meta.refresh() - meta - } - - override def equals(other: Any): Boolean = other match { - case that: ParquetRelation2 => - val schemaEquality = if (shouldMergeSchemas) { - this.shouldMergeSchemas == that.shouldMergeSchemas - } else { - this.dataSchema == that.dataSchema && - this.schema == that.schema - } - - this.paths.toSet == that.paths.toSet && - schemaEquality && - this.maybeDataSchema == that.maybeDataSchema && - this.partitionColumns == that.partitionColumns - - case _ => false - } - - override def hashCode(): Int = { - if (shouldMergeSchemas) { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - maybeDataSchema, - partitionColumns) - } else { - Objects.hashCode( - Boolean.box(shouldMergeSchemas), - paths.toSet, - dataSchema, - schema, - maybeDataSchema, - partitionColumns) - } - } - - /** Constraints on schema of dataframe to be stored. */ - private def checkConstraints(schema: StructType): Unit = { - if (schema.fieldNames.length != schema.fieldNames.distinct.length) { - val duplicateColumns = schema.fieldNames.groupBy(identity).collect { - case (x, ys) if ys.length > 1 => "\"" + x + "\"" - }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to parquet format") - } - } - - override def dataSchema: StructType = { - val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) - // check if schema satisfies the constraints - // before moving forward - checkConstraints(schema) - schema - } - - override private[sql] def refresh(): Unit = { - super.refresh() - metadataCache.refresh() - } - - // Parquet data source always uses Catalyst internal representations. 
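
// The duplicate-column check in `checkConstraints` above, reduced to plain strings
// (column names invented for illustration):
val fieldNames = Seq("id", "name", "id")
val duplicateColumns = fieldNames.groupBy(identity).collect {
  case (x, ys) if ys.length > 1 => "\"" + x + "\""
}.mkString(", ")
// duplicateColumns == "\"id\"", which would trigger the AnalysisException above
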
- override val needConversion: Boolean = false - - override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = ContextUtil.getConfiguration(job) - - val committerClass = - conf.getClass( - SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, - classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) - - if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { - logInfo("Using default output committer for Parquet: " + - classOf[ParquetOutputCommitter].getCanonicalName) - } else { - logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) - } - - conf.setClass( - SQLConf.OUTPUT_COMMITTER_CLASS.key, - committerClass, - classOf[ParquetOutputCommitter]) - - // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override - // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why - // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is - // bundled with `ParquetOutputFormat[Row]`. - job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) - - // TODO There's no need to use two kinds of WriteSupport - // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and - // complex types. - val writeSupportClass = - if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { - classOf[MutableRowWriteSupport] - } else { - classOf[RowWriteSupport] - } - - ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) - RowWriteSupport.setSchema(dataSchema.toAttributes, conf) - - // Sets compression scheme - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - sqlContext.conf.parquetCompressionCodec.toUpperCase, - CompressionCodecName.UNCOMPRESSED).name()) - - new OutputWriterFactory { - override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) - } - } - } - - override def buildScan( - requiredColumns: Array[String], - filters: Array[Filter], - inputFiles: Array[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) - val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec - - // Create the function to set variable Parquet confs at both driver and executor side. - val initLocalJobFuncOpt = - ParquetRelation2.initializeLocalJobFunc( - requiredColumns, - filters, - dataSchema, - useMetadataCache, - parquetFilterPushDown, - assumeBinaryIsString, - assumeInt96IsTimestamp, - followParquetFormatSpec) _ - - // Create the function to set input paths at the driver side. 
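
// The write-support choice in `prepareJobForWrite` above, distilled with strings in
// place of Catalyst `DataType`s (illustrative only): the mutable-row fast path is
// taken only when every column type is primitive.
val primitiveTypes = Set("int", "long", "double", "string")
val outputTypes = Seq("int", "string", "array<int>")
val writeSupport =
  if (outputTypes.forall(primitiveTypes)) "MutableRowWriteSupport" else "RowWriteSupport"
// writeSupport == "RowWriteSupport": the nested array column forces the generic path
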
- val setInputPaths = ParquetRelation2.initializeDriverSideJobFunc(inputFiles) _ - - Utils.withDummyCallSite(sqlContext.sparkContext) { - new SqlNewHadoopRDD( - sc = sqlContext.sparkContext, - broadcastedConf = broadcastedConf, - initDriverSideJobFuncOpt = Some(setInputPaths), - initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[ParquetInputFormat[InternalRow]], - keyClass = classOf[Void], - valueClass = classOf[InternalRow]) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as '/' - // (which does happen in some S3N credentials), we need to use the string returned by the - // URI of the path to create a new Path. - val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) - new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) - }.toSeq - - private def escapePathUserInfo(path: Path): Path = { - val uri = path.toUri - new Path(new URI( - uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, - uri.getQuery, uri.getFragment)) - } - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = new ParquetInputFormat[InternalRow] { - override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) - } - } - - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - } - }.values.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] - } - } - - private class MetadataCache { - // `FileStatus` objects of all "_metadata" files. - private var metadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all "_common_metadata" files. - private var commonMetadataStatuses: Array[FileStatus] = _ - - // `FileStatus` objects of all data files (Parquet part-files). - var dataStatuses: Array[FileStatus] = _ - - // Schema of the actual Parquet files, without partition columns discovered from partition - // directory paths. - var dataSchema: StructType = null - - // Schema of the whole table, including partition columns. - var schema: StructType = _ - - // Cached leaves - var cachedLeaves: Set[FileStatus] = null - - /** - * Refreshes `FileStatus`es, footers, partition spec, and table schema. - */ - def refresh(): Unit = { - val currentLeafStatuses = cachedLeafStatuses() - - // Check if cachedLeafStatuses is changed or not - val leafStatusesChanged = (cachedLeaves == null) || - !cachedLeaves.equals(currentLeafStatuses) - - if (leafStatusesChanged) { - cachedLeaves = currentLeafStatuses.toIterator.toSet - - // Lists `FileStatus`es of all leaf nodes (files) under all base directories. 
-        // Lists `FileStatus`es of all leaf nodes (files) under all base directories.
-        val leaves = currentLeafStatuses.filter { f =>
-          isSummaryFile(f.getPath) ||
-            !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith("."))
-        }.toArray
-
-        dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath))
-        metadataStatuses =
-          leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)
-        commonMetadataStatuses =
-          leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)
-
-        dataSchema = {
-          val dataSchema0 = maybeDataSchema
-            .orElse(readSchema())
-            .orElse(maybeMetastoreSchema)
-            .getOrElse(throw new AnalysisException(
-              s"Failed to discover schema of Parquet file(s) in the following location(s):\n" +
-                paths.mkString("\n\t")))
-
-          // If this Parquet relation is converted from a Hive Metastore table, we must reconcile
-          // the case insensitivity issue and possible schema mismatch (probably caused by schema
-          // evolution).
-          maybeMetastoreSchema
-            .map(ParquetRelation2.mergeMetastoreParquetSchema(_, dataSchema0))
-            .getOrElse(dataSchema0)
-        }
-      }
-    }
-
-    private def isSummaryFile(file: Path): Boolean = {
-      file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
-        file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
-    }
-
-    private def readSchema(): Option[StructType] = {
-      // Sees which file(s) we need to touch in order to figure out the schema.
-      //
-      // Always tries the summary files first if users don't require a merged schema. In this case,
-      // "_common_metadata" is preferable to "_metadata" because it doesn't contain row
-      // groups information, and could be much smaller for large Parquet files with lots of row
-      // groups. If no summary file is available, falls back to some random part-file.
-      //
-      // NOTE: Metadata stored in the summary files are merged from all part-files. However, for
-      // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know
-      // how to merge them correctly if some key is associated with different values in different
-      // part-files. When this happens, Parquet simply gives up generating the summary file. This
-      // implies that if a summary file is present, then:
-      //
-      //   1. Either all part-files have exactly the same Spark SQL schema, or
-      //   2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus
-      //      their schemas may differ from each other).
-      //
-      // Here we tend to be pessimistic and take the second case into account. Basically this means
-      // we can't trust the summary files if users require a merged schema, and must touch all part-
-      // files to do the merge.
-      val filesToTouch =
-        if (shouldMergeSchemas) {
-          // Also includes summary files, 'cause there might be empty partition directories.
-          (metadataStatuses ++ commonMetadataStatuses ++ dataStatuses).toSeq
-        } else {
-          // Tries any "_common_metadata" first. Parquet files written by old versions of Parquet
-          // don't have this.
-          commonMetadataStatuses.headOption
-            // Falls back to "_metadata"
-            .orElse(metadataStatuses.headOption)
-            // Summary file(s) not found, the Parquet file is either corrupted, or different part-
-            // files contain conflicting user defined metadata (two or more values are associated
-            // with a same key in different files). In either case, we fall back to the first
-            // part-file, and just assume all schemas are consistent.
-            .orElse(dataStatuses.headOption)
-            .toSeq
-        }
-
-      assert(
-        filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined,
-        "No predefined schema found, " +
-          s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.")
-
-      ParquetRelation2.mergeSchemasInParallel(filesToTouch, sqlContext)
-    }
-  }
-}
-
-private[sql] object ParquetRelation2 extends Logging {
-  // Whether we should merge schemas collected from all Parquet part-files.
-  private[sql] val MERGE_SCHEMA = "mergeSchema"
-
-  // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used
-  // internally.
-  private[sql] val METASTORE_SCHEMA = "metastoreSchema"
-
-  /** This closure sets various Parquet configurations at both driver side and executor side. */
-  private[parquet] def initializeLocalJobFunc(
-      requiredColumns: Array[String],
-      filters: Array[Filter],
-      dataSchema: StructType,
-      useMetadataCache: Boolean,
-      parquetFilterPushDown: Boolean,
-      assumeBinaryIsString: Boolean,
-      assumeInt96IsTimestamp: Boolean,
-      followParquetFormatSpec: Boolean)(job: Job): Unit = {
-    val conf = job.getConfiguration
-    conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[RowReadSupport].getName)
-
-    // Try to push down filters when filter push-down is enabled.
-    if (parquetFilterPushDown) {
-      filters
-        // Collects all converted Parquet filter predicates. Notice that not all predicates can be
-        // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap`
-        // is used here.
-        .flatMap(ParquetFilters.createFilter(dataSchema, _))
-        .reduceOption(FilterApi.and)
-        .foreach(ParquetInputFormat.setFilterPredicate(conf, _))
-    }
-
-    conf.set(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, {
-      val requestedSchema = StructType(requiredColumns.map(dataSchema(_)))
-      ParquetTypesConverter.convertToString(requestedSchema.toAttributes)
-    })
-
-    conf.set(
-      RowWriteSupport.SPARK_ROW_SCHEMA,
-      ParquetTypesConverter.convertToString(dataSchema.toAttributes))
-
-    // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata
-    conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache)
-
-    // Sets flags for Parquet schema conversion
-    conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString)
-    conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp)
-    conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec)
-  }
-
-  /** This closure sets input paths at the driver side. */
-  private[parquet] def initializeDriverSideJobFunc(
-      inputFiles: Array[FileStatus])(job: Job): Unit = {
-    // We set the input paths at the driver side.
-    logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}")
-    if (inputFiles.nonEmpty) {
-      FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*)
-    }
-  }
-
-  private[parquet] def readSchema(
-      footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = {
-
-    def parseParquetSchema(schema: MessageType): StructType = {
-      StructType.fromAttributes(
-        // TODO Really no need to use `Attribute` here, we only need to know the data type.
-        ParquetTypesConverter.convertToAttributes(
-          schema,
-          sqlContext.conf.isParquetBinaryAsString,
-          sqlContext.conf.isParquetINT96AsTimestamp))
-    }
-
-    val seen = mutable.HashSet[String]()
-    val finalSchemas: Seq[StructType] = footers.flatMap { footer =>
-      val metadata = footer.getParquetMetadata.getFileMetaData
-      val serializedSchema = metadata
-        .getKeyValueMetaData
-        .toMap
-        .get(RowReadSupport.SPARK_METADATA_KEY)
-      if (serializedSchema.isEmpty) {
-        // Falls back to Parquet schema if no Spark SQL schema found.
-        Some(parseParquetSchema(metadata.getSchema))
-      } else if (!seen.contains(serializedSchema.get)) {
-        seen += serializedSchema.get
-
-        // Don't throw even if we failed to parse the serialized Spark schema. Just fall back to
-        // whatever is available.
-        Some(Try(DataType.fromJson(serializedSchema.get))
-          .recover { case _: Throwable =>
-            logInfo(
-              s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " +
-                "falling back to the deprecated DataType.fromCaseClassString parser.")
-            DataType.fromCaseClassString(serializedSchema.get)
-          }
-          .recover { case cause: Throwable =>
-            logWarning(
-              s"""Failed to parse serialized Spark schema in Parquet key-value metadata:
-                 |\t$serializedSchema
-               """.stripMargin,
-              cause)
-          }
-          .map(_.asInstanceOf[StructType])
-          .getOrElse {
-            // Falls back to Parquet schema if Spark SQL schema can't be parsed.
-            parseParquetSchema(metadata.getSchema)
-          })
-      } else {
-        None
-      }
-    }
-
-    finalSchemas.reduceOption { (left, right) =>
-      try left.merge(right) catch { case e: Throwable =>
-        throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e)
-      }
-    }
-  }
-
-  /**
-   * Reconciles the Hive Metastore case insensitivity issue and data type conflicts between
-   * Metastore schema and Parquet schema.
-   *
-   * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the
-   * schema read from Parquet files may be incomplete (e.g. older versions of Parquet don't
-   * distinguish binary and string). This method generates a correct schema by merging Metastore
-   * schema data types and Parquet schema field names.
-   */
-  private[parquet] def mergeMetastoreParquetSchema(
-      metastoreSchema: StructType,
-      parquetSchema: StructType): StructType = {
-    def schemaConflictMessage: String =
-      s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema:
-         |${metastoreSchema.prettyJson}
-         |
-         |Parquet schema:
-         |${parquetSchema.prettyJson}
-       """.stripMargin
-
-    val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema)
-
-    assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage)
-
-    val ordinalMap = metastoreSchema.zipWithIndex.map {
-      case (field, index) => field.name.toLowerCase -> index
-    }.toMap
-
-    val reorderedParquetSchema = mergedParquetSchema.sortBy(f =>
-      ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1))
-
-    StructType(metastoreSchema.zip(reorderedParquetSchema).map {
-      // Uses Parquet field names but retains Metastore data types.
-      case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase =>
-        mSchema.copy(name = pSchema.name)
-      case _ =>
-        throw new SparkException(schemaConflictMessage)
-    })
-  }
-
-  /**
-   * Returns the original schema from the Parquet file with any missing nullable fields from the
-   * Hive Metastore schema merged in.
-   *
-   * When constructing a DataFrame from a collection of structured data, the resulting object has
-   * a schema corresponding to the union of the fields present in each element of the collection.
-   * Spark SQL simply assigns a null value to any field that isn't present for a particular row.
-   * In some cases, it is possible that a given table partition stored as a Parquet file doesn't
-   * contain a particular nullable field in its schema despite that field being present in the
-   * table schema obtained from the Hive Metastore. This method returns a schema representing the
-   * Parquet file schema along with any additional nullable fields from the Metastore schema
-   * merged in.
-   */
-  private[parquet] def mergeMissingNullableFields(
-      metastoreSchema: StructType,
-      parquetSchema: StructType): StructType = {
-    val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap
-    val missingFields = metastoreSchema
-      .map(_.name.toLowerCase)
-      .diff(parquetSchema.map(_.name.toLowerCase))
-      .map(fieldMap(_))
-      .filter(_.nullable)
-    StructType(parquetSchema ++ missingFields)
-  }
-
-  /**
-   * Figures out a merged Parquet schema with a distributed Spark job.
-   *
-   * Note that locality is not taken into consideration here because:
-   *
-   *  1. For a single Parquet part-file, in most cases the footer only resides in the last block of
-   *     that file. Thus we only need to retrieve the location of the last block. However, Hadoop
-   *     `FileSystem` only provides API to retrieve locations of all blocks, which can be
-   *     potentially expensive.
-   *
-   *  2. This optimization is mainly useful for S3, where file metadata operations can be pretty
-   *     slow. And basically locality is not available when using S3 (you can't run computation on
-   *     S3 nodes).
-   */
-  def mergeSchemasInParallel(
-      filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = {
-    val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString
-    val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp
-    val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec
-    val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration)
-
-    // HACK ALERT:
-    //
-    // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es
-    // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable`
-    // but only `Writable`. What makes it worse, for some reason, `FileStatus` doesn't play well
-    // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These
-    // facts virtually prevent us from serializing `FileStatus`es.
-    //
-    // Since Parquet only relies on path and length information of those `FileStatus`es to read
-    // footers, here we just extract them (which can be easily serialized), send them to executor
-    // side, and reconstruct fake `FileStatus`es there.
-    val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen))
-
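Since only `(path, length)` pairs cross the wire, the executor-side rebuild is trivial. A standalone sketch with a made-up path (the ten-argument `FileStatus` constructor zeroes out every field the footer reader does not consult):

```scala
import org.apache.hadoop.fs.{FileStatus, Path}

object FakeStatusSketch extends App {
  // Only path and length survive serialization; everything else is zeroed or nulled.
  val partialInfo: Seq[(String, Long)] = Seq(("hdfs://nn/table/part-00000.parquet", 1024L))

  val fakeStatuses: Seq[FileStatus] = partialInfo.map { case (path, length) =>
    new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path))
  }

  fakeStatuses.foreach(s => println(s"${s.getPath} -> ${s.getLen} bytes"))
}
```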
-    // Issues a Spark job to read Parquet schema in parallel.
-    val partiallyMergedSchemas =
-      sqlContext
-        .sparkContext
-        .parallelize(partialFileStatusInfo)
-        .mapPartitions { iterator =>
-          // Reconstructs fake `FileStatus`es from serialized path and length information.
-          val fakeFileStatuses = iterator.map { case (path, length) =>
-            new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path))
-          }.toSeq
-
-          // Skips row group information since we only need the schema
-          val skipRowGroups = true
-
-          // Reads footers in a multi-threaded manner within each task
-          val footers =
-            ParquetFileReader.readAllFootersInParallel(
-              serializedConf.value, fakeFileStatuses, skipRowGroups)
-
-          // Converter used to convert Parquet `MessageType` to Spark SQL `StructType`
-          val converter =
-            new CatalystSchemaConverter(
-              assumeBinaryIsString = assumeBinaryIsString,
-              assumeInt96IsTimestamp = assumeInt96IsTimestamp,
-              followParquetFormatSpec = followParquetFormatSpec)
-
-          footers.map { footer =>
-            ParquetRelation2.readSchemaFromFooter(footer, converter)
-          }.reduceOption(_ merge _).iterator
-        }.collect()
-
-    partiallyMergedSchemas.reduceOption(_ merge _)
-  }
-
-  /**
-   * Reads Spark SQL schema from a Parquet footer. If a valid serialized Spark SQL schema string
-   * can be found in the file metadata, returns the deserialized [[StructType]]; otherwise, returns
-   * a [[StructType]] converted from the [[MessageType]] stored in this footer.
-   */
-  def readSchemaFromFooter(
-      footer: Footer, converter: CatalystSchemaConverter): StructType = {
-    val fileMetaData = footer.getParquetMetadata.getFileMetaData
-    fileMetaData
-      .getKeyValueMetaData
-      .toMap
-      .get(RowReadSupport.SPARK_METADATA_KEY)
-      .flatMap(deserializeSchemaString)
-      .getOrElse(converter.convert(fileMetaData.getSchema))
-  }
-
-  private def deserializeSchemaString(schemaString: String): Option[StructType] = {
-    // Tries to deserialize the schema string as JSON first, then falls back to the case class
-    // string parser (data generated by older versions of Spark SQL uses this format).
-    Try(DataType.fromJson(schemaString).asInstanceOf[StructType]).recover {
-      case _: Throwable =>
-        logInfo(
-          s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " +
-            "falling back to the deprecated DataType.fromCaseClassString parser.")
-        DataType.fromCaseClassString(schemaString).asInstanceOf[StructType]
-    }.recoverWith {
-      case cause: Throwable =>
-        logWarning(
-          "Failed to parse serialized Spark schema in " +
-            s"Parquet key-value metadata; ignoring it:\n\t$schemaString", cause)
-        Failure(cause)
-    }.toOption
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index 23df102cd951d..b6a7c4fbddbdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.parquet
 
-import org.scalatest.BeforeAndAfterAll
 import org.apache.parquet.filter2.predicate.Operators._
 import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
 
@@ -40,7 +39,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
  *    2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
  *       data type is nullable.
*/ -class ParquetFilterSuiteBase extends QueryTest with ParquetTest { +class ParquetFilterSuite extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext private def checkFilterPredicate( @@ -56,17 +55,9 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val maybeAnalyzedPredicate = { - val forParquetTableScan = query.queryExecution.executedPlan.collect { - case plan: ParquetTableScan => plan.columnPruningPred - }.flatten.reduceOption(_ && _) - - val forParquetDataSource = query.queryExecution.optimizedPlan.collect { - case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation2)) => filters - }.flatten.reduceOption(_ && _) - - forParquetTableScan.orElse(forParquetDataSource) - } + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation)) => filters + }.flatten.reduceOption(_ && _) assert(maybeAnalyzedPredicate.isDefined) maybeAnalyzedPredicate.foreach { pred => @@ -98,7 +89,7 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) (implicit df: DataFrame): Unit = { def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = { - assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { + assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).sorted) { df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } } @@ -308,18 +299,6 @@ class ParquetFilterSuiteBase extends QueryTest with ParquetTest { '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b))) } } -} - -class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("SPARK-6554: don't push down predicates which reference partition columns") { import sqlContext.implicits._ @@ -338,37 +317,3 @@ class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeA } } } - -class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll { - lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } - - test("SPARK-6742: don't push down predicates which reference partition columns") { - import sqlContext.implicits._ - - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { - withTempPath { dir => - val path = s"${dir.getCanonicalPath}/part=1" - (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path) - - // If the "part = 1" filter gets pushed down, this query will throw an exception since - // "part" is not a valid column in the actual Parquet file - val df = DataFrame(sqlContext, org.apache.spark.sql.parquet.ParquetRelation( - path, - Some(sqlContext.sparkContext.hadoopConfiguration), sqlContext, - Seq(AttributeReference("part", IntegerType, false)()) )) - - checkAnswer( - df.filter("a = 1 or part = 1"), 
- (1 to 3).map(i => Row(1, i, i.toString))) - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 3a5b860484e86..b5314a3dd92e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -32,7 +32,6 @@ import org.apache.parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, P import org.apache.parquet.hadoop.{Footer, ParquetFileWriter, ParquetOutputCommitter, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} -import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql._ @@ -63,7 +62,7 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS /** * A test suite that tests basic Parquet I/O. */ -class ParquetIOSuiteBase extends QueryTest with ParquetTest { +class ParquetIOSuite extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ @@ -357,7 +356,7 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { """.stripMargin) withTempPath { location => - val extraMetadata = Map(RowReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) + val extraMetadata = Map(CatalystReadSupport.SPARK_METADATA_KEY -> sparkSchema.toString) val fileMetadata = new FileMetaData(parquetSchema, extraMetadata, "Spark") val path = new Path(location.getCanonicalPath) @@ -422,26 +421,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } } } -} - -class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - - override def commitJob(jobContext: JobContext): Unit = { - sys.error("Intentional exception for testing purposes") - } -} - -class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key, originalConf.toString) - } test("SPARK-6330 regression test") { // In 1.3.0, save to fs other than file: without configuring core-site.xml would get: @@ -456,14 +435,10 @@ class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterA } } -class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } +class BogusParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) + extends ParquetOutputCommitter(outputPath, context) { - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) + override def commitJob(jobContext: JobContext): Unit = { + sys.error("Intentional exception for testing purposes") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala index 7f16b1125c7a5..2eef10189f11c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala @@ -467,7 +467,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest { (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath) val queryExecution = sqlContext.read.parquet(dir.getCanonicalPath).queryExecution queryExecution.analyzed.collectFirst { - case LogicalRelation(relation: ParquetRelation2) => + case LogicalRelation(relation: ParquetRelation) => assert(relation.partitionSpec === PartitionSpec.emptySpec) }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$queryExecution") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 21007d95ed752..c037faf4cfd92 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.parquet import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} @@ -26,7 +25,7 @@ import org.apache.spark.sql.{QueryTest, Row, SQLConf} /** * A test suite that tests various Parquet queries. */ -class ParquetQuerySuiteBase extends QueryTest with ParquetTest { +class ParquetQuerySuite extends QueryTest with ParquetTest { lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.sql @@ -164,27 +163,3 @@ class ParquetQuerySuiteBase extends QueryTest with ParquetTest { } } } - -class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} - -class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll { - private lazy val originalConf = sqlContext.conf.parquetUseDataSourceApi - - override protected def beforeAll(): Unit = { - sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override protected def afterAll(): Unit = { - sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index fa629392674bd..4a0b3b60f419d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -378,7 +378,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("lowerCase", StringType), StructField("UPPERCase", DoubleType, nullable = false)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("lowercase", StringType), StructField("uppercase", DoubleType, nullable = false))), @@ -393,7 +393,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructType(Seq( StructField("UPPERCase", DoubleType, nullable = false)))) { - 
ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false))), @@ -404,7 +404,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Metastore schema contains additional non-nullable fields. assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("uppercase", DoubleType, nullable = false), StructField("lowerCase", BinaryType, nullable = false))), @@ -415,7 +415,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Conflicting non-nullable field names intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq(StructField("lower", StringType, nullable = false))), StructType(Seq(StructField("lowerCase", BinaryType)))) } @@ -429,7 +429,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { StructField("firstField", StringType, nullable = true), StructField("secondField", StringType, nullable = true), StructField("thirdfield", StringType, nullable = true)))) { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), @@ -442,7 +442,7 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // Merge should fail if the Metastore contains any additional fields that are not // nullable. assert(intercept[Throwable] { - ParquetRelation2.mergeMetastoreParquetSchema( + ParquetRelation.mergeMetastoreParquetSchema( StructType(Seq( StructField("firstfield", StringType, nullable = true), StructField("secondfield", StringType, nullable = true), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 4cdb83c5116f9..1b8edefef4093 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -444,9 +444,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { HiveDDLStrategy, DDLStrategy, TakeOrderedAndProject, - ParquetOperations, InMemoryScans, - ParquetConversion, // Must be before HiveTableScans HiveTableScans, DataSinks, Scripts, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 0a2121c955871..262923531216f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConversions._ import com.google.common.base.Objects import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.metastore.Warehouse @@ -30,7 +29,6 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -39,10 +37,11 @@ import 
org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, PartitionSpec, CreateTableUsingAsSelect, ResolvedDataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) @@ -260,8 +259,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // serialize the Metastore schema to JSON and pass it as a data source option because of the // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. val parquetOptions = Map( - ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json, - ParquetRelation2.MERGE_SCHEMA -> mergeSchema.toString) + ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -272,7 +271,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: ParquetRelation2) => + case logical@LogicalRelation(parquetRelation: ParquetRelation) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = @@ -317,7 +316,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, Some(partitionSpec)) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new ParquetRelation2( + new ParquetRelation( paths.toArray, None, Some(partitionSpec), parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created @@ -330,7 +329,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val cached = getCached(tableIdentifier, paths, metastoreSchema, None) val parquetRelation = cached.getOrElse { val created = LogicalRelation( - new ParquetRelation2(paths.toArray, None, None, parquetOptions)(hive)) + new ParquetRelation(paths.toArray, None, None, parquetOptions)(hive)) cachedDataSourceTables.put(tableIdentifier, created) created } @@ -370,8 +369,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive /** * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet * data source relations for better performance. - * - * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right. */ object ParquetConversions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { @@ -386,7 +383,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Inserting into partitioned table is not supported in Parquet data source (yet). 
if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -397,7 +393,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Inserting into partitioned table is not supported in Parquet data source (yet). if !relation.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) @@ -406,7 +401,6 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // Read path case p @ PhysicalOperation(_, _, relation: MetastoreRelation) if hive.convertMetastoreParquet && - conf.parquetUseDataSourceApi && relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => val parquetRelation = convertToParquetRelation(relation) val attributedRewrites = relation.output.zip(parquetRelation.output) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index a22c3292eff94..cd6cd322c94ed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,23 +17,14 @@ package org.apache.spark.sql.hive -import scala.collection.JavaConversions._ - -import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.sql.types.StringType private[hive] trait HiveStrategies { @@ -42,136 +33,6 @@ private[hive] trait HiveStrategies { val hiveContext: HiveContext - /** - * :: Experimental :: - * Finds table scans that would use the Hive SerDe and replaces them with our own native parquet - * table scan operator. - * - * TODO: Much of this logic is duplicated in HiveTableScan. Ideally we would do some refactoring - * but since this is after the code freeze for 1.1 all logic is here to minimize disruption. - * - * Other issues: - * - Much of this logic assumes case insensitive resolution. - */ - @Experimental - object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: DataFrame) { - def lowerCase: DataFrame = DataFrame(s.sqlContext, s.logicalPlan) - - def addPartitioningAttributes(attrs: Seq[Attribute]): DataFrame = { - // Don't add the partitioning key if its already present in the data. 
- if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { - s - } else { - DataFrame( - s.sqlContext, - s.logicalPlan transform { - case p: ParquetRelation => p.copy(partitioningAttributes = attrs) - }) - } - } - } - - implicit class PhysicalPlanHacks(originalPlan: SparkPlan) { - def fakeOutput(newOutput: Seq[Attribute]): OutputFaker = - OutputFaker( - originalPlan.output.map(a => - newOutput.find(a.name.toLowerCase == _.name.toLowerCase) - .getOrElse( - sys.error(s"Can't find attribute $a to fake in set ${newOutput.mkString(",")}"))), - originalPlan) - } - - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projectList, predicates, relation: MetastoreRelation) - if relation.tableDesc.getSerdeClassName.contains("Parquet") && - hiveContext.convertMetastoreParquet && - !hiveContext.conf.parquetUseDataSourceApi => - - // Filter out all predicates that only deal with partition keys - val partitionsKeys = AttributeSet(relation.partitionKeys) - val (pruningPredicates, otherPredicates) = predicates.partition { - _.references.subsetOf(partitionsKeys) - } - - // We are going to throw the predicates and projection back at the whole optimization - // sequence so lets unresolve all the attributes, allowing them to be rebound to the - // matching parquet attributes. - val unresolvedOtherPredicates = Column(otherPredicates.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true))) - - val unresolvedProjection: Seq[Column] = projectList.map(_ transform { - case a: AttributeReference => UnresolvedAttribute(a.name) - }).map(Column(_)) - - try { - if (relation.hiveQlTable.isPartitioned) { - val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true)) - // Translate the predicate so that it automatically casts the input values to the - // correct data types during evaluation. - val castedPredicate = rawPredicate transform { - case a: AttributeReference => - val idx = relation.partitionKeys.indexWhere(a.exprId == _.exprId) - val key = relation.partitionKeys(idx) - Cast(BoundReference(idx, StringType, nullable = true), key.dataType) - } - - val inputData = new GenericMutableRow(relation.partitionKeys.size) - val pruningCondition = - if (codegenEnabled) { - GeneratePredicate.generate(castedPredicate) - } else { - InterpretedPredicate.create(castedPredicate) - } - - val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part => - val partitionValues = part.getValues - var i = 0 - while (i < partitionValues.size()) { - inputData(i) = CatalystTypeConverters.convertToCatalyst(partitionValues(i)) - i += 1 - } - pruningCondition(inputData) - } - - val partitionLocations = partitions.map(_.getLocation) - - if (partitionLocations.isEmpty) { - PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil - } else { - hiveContext - .read.parquet(partitionLocations: _*) - .addPartitioningAttributes(relation.partitionKeys) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - - } else { - hiveContext - .read.parquet(relation.hiveQlTable.getDataLocation.toString) - .lowerCase - .where(unresolvedOtherPredicates) - .select(unresolvedProjection: _*) - .queryExecution - .executedPlan - .fakeOutput(projectList.map(_.toAttribute)) :: Nil - } - } catch { - // parquetFile will throw an exception when there is no data. 
- // TODO: Remove this hack for Spark 1.3. - case iae: java.lang.IllegalArgumentException - if iae.getMessage.contains("Can not create a Path from an empty string") => - PhysicalRDD(plan.output, sparkContext.emptyRDD[InternalRow]) :: Nil - } - case _ => Nil - } - } - object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index af68615e8e9d6..a45c2d957278f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.parquet.ParquetTest -import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) @@ -28,64 +28,54 @@ class HiveParquetSuite extends QueryTest with ParquetTest { import sqlContext._ - def run(prefix: String): Unit = { - test(s"$prefix: Case insensitive attribute names") { - withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { - val expected = (1 to 4).map(i => Row(i.toString)) - checkAnswer(sql("SELECT upper FROM cases"), expected) - checkAnswer(sql("SELECT LOWER FROM cases"), expected) - } + test("Case insensitive attribute names") { + withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") { + val expected = (1 to 4).map(i => Row(i.toString)) + checkAnswer(sql("SELECT upper FROM cases"), expected) + checkAnswer(sql("SELECT LOWER FROM cases"), expected) } + } - test(s"$prefix: SELECT on Parquet table") { - val data = (1 to 4).map(i => (i, s"val_$i")) - withParquetTable(data, "t") { - checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) - } + test("SELECT on Parquet table") { + val data = (1 to 4).map(i => (i, s"val_$i")) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT * FROM t"), data.map(Row.fromTuple)) } + } - test(s"$prefix: Simple column projection + filter on Parquet table") { - withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { - checkAnswer( - sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), - Seq(Row(true, "val_2"), Row(true, "val_4"))) - } + test("Simple column projection + filter on Parquet table") { + withParquetTable((1 to 4).map(i => (i % 2 == 0, i, s"val_$i")), "t") { + checkAnswer( + sql("SELECT `_1`, `_3` FROM t WHERE `_1` = true"), + Seq(Row(true, "val_2"), Row(true, "val_4"))) } + } - test(s"$prefix: Converting Hive to Parquet Table via saveAsParquetFile") { - withTempPath { dir => - sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) - read.parquet(dir.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - checkAnswer( - sql("SELECT * FROM src ORDER BY key"), - sql("SELECT * from p ORDER BY key").collect().toSeq) - } + test("Converting Hive to Parquet Table via saveAsParquetFile") { + withTempPath { dir => + sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath) + read.parquet(dir.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + checkAnswer( + sql("SELECT * FROM src ORDER BY key"), + sql("SELECT * from p ORDER BY key").collect().toSeq) } } + } - test(s"$prefix: INSERT OVERWRITE TABLE Parquet table") { - withParquetTable((1 to 
10).map(i => (i, s"val_$i")), "t") { - withTempPath { file => - sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) - read.parquet(file.getCanonicalPath).registerTempTable("p") - withTempTable("p") { - // let's do three overwrites for good measure - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - sql("INSERT OVERWRITE TABLE p SELECT * FROM t") - checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) - } + test("INSERT OVERWRITE TABLE Parquet table") { + withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") { + withTempPath { file => + sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath) + read.parquet(file.getCanonicalPath).registerTempTable("p") + withTempTable("p") { + // let's do three overwrites for good measure + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + sql("INSERT OVERWRITE TABLE p SELECT * FROM t") + checkAnswer(sql("SELECT * FROM p"), sql("SELECT * FROM t").collect().toSeq) } } } } - - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { - run("Parquet data source enabled") - } - - withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "false") { - run("Parquet data source disabled") - } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index e403f32efaf91..4fdf774ead75e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -21,10 +21,9 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred.InvalidInputException +import org.scalatest.BeforeAndAfterAll import org.apache.spark.Logging import org.apache.spark.sql._ @@ -33,7 +32,7 @@ import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -564,10 +563,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA } test("scan a parquet table created through a CTAS statement") { - withSQLConf( - HiveContext.CONVERT_METASTORE_PARQUET.key -> "true", - SQLConf.PARQUET_USE_DATA_SOURCE_API.key -> "true") { - + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "true") { withTempTable("jt") { (1 to 10).map(i => i -> s"str$i").toDF("a", "b").registerTempTable("jt") @@ -582,9 +578,9 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeA Row(3) :: Row(4) :: Nil) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(p: ParquetRelation2) => // OK + case LogicalRelation(p: ParquetRelation) => // OK case _ => - fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation2]}") + fail(s"test_parquet_ctas should have be converted to ${classOf[ParquetRelation]}") } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 03428265422e6..ff42fdefaa62a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,8 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} -import org.apache.spark.sql.parquet.ParquetRelation2 +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ case class Nested1(f1: Nested2) @@ -61,7 +62,9 @@ class MyDialect extends DefaultParserDialect * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is * valid, but Hive currently cannot execute it. */ -class SQLQuerySuite extends QueryTest { +class SQLQuerySuite extends QueryTest with SQLTestUtils { + override def sqlContext: SQLContext = TestHive + test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") @@ -195,17 +198,17 @@ class SQLQuerySuite extends QueryTest { def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = { val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) relation match { - case LogicalRelation(r: ParquetRelation2) => + case LogicalRelation(r: ParquetRelation) => if (!isDataSourceParquet) { fail( s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " + - s"${ParquetRelation2.getClass.getCanonicalName}.") + s"${ParquetRelation.getClass.getCanonicalName}.") } case r: MetastoreRelation => if (isDataSourceParquet) { fail( - s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " + + s"${ParquetRelation.getClass.getCanonicalName} is expected, but found " + s"${classOf[MetastoreRelation].getCanonicalName}.") } } @@ -350,33 +353,26 @@ class SQLQuerySuite extends QueryTest { "serde_p1=p1", "serde_p2=p2", "tbl_p1=p11", "tbl_p2=p22", "MANAGED_TABLE" ) - val origUseParquetDataSource = conf.parquetUseDataSourceApi - try { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - sql( - """CREATE TABLE ctas5 - | STORED AS parquet AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin).collect() - - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) - - val default = convertMetastoreParquet - // use the Hive SerDe for parquet tables - sql("set spark.sql.hive.convertMetastoreParquet = false") + sql( + """CREATE TABLE ctas5 + | STORED AS parquet AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin).collect() + + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + + // use the Hive SerDe for parquet tables + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { checkAnswer( sql("SELECT key, value FROM 
ctas5 ORDER BY key, value"), sql("SELECT key, value FROM src ORDER BY key, value").collect().toSeq) - sql(s"set spark.sql.hive.convertMetastoreParquet = $default") - } finally { - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, origUseParquetDataSource) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 82a8daf8b4b09..f56fb96c52d37 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -22,13 +22,13 @@ import java.io.File import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ -import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} import org.apache.spark.sql.hive.execution.HiveTableScan import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ -import org.apache.spark.sql.parquet.{ParquetRelation2, ParquetTableScan} +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -57,7 +57,7 @@ case class ParquetDataWithKeyAndComplexTypes( * A suite to test the automatic conversion of metastore tables with parquet data to use the * built in parquet support. */ -class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { +class ParquetMetastoreSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -134,6 +134,19 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}' """) + sql( + """ + |create table test_parquet + |( + | intField INT, + | stringField STRING + |) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |STORED AS + | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + """.stripMargin) + (1 to 10).foreach { p => sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)") } @@ -166,6 +179,7 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { sql("DROP TABLE normal_parquet") sql("DROP TABLE IF EXISTS jt") sql("DROP TABLE IF EXISTS jt_array") + sql("DROP TABLE IF EXISTS test_parquet") setConf(HiveContext.CONVERT_METASTORE_PARQUET, false) } @@ -176,40 +190,9 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest { }.isEmpty) assert( sql("SELECT * FROM normal_parquet").queryExecution.executedPlan.collect { - case _: ParquetTableScan => true case _: PhysicalRDD => true }.nonEmpty) } -} - -class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - - sql( - """ - |create table test_parquet - |( - | intField INT, - | stringField STRING - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - 
conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override def afterAll(): Unit = { - super.afterAll() - sql("DROP TABLE IF EXISTS test_parquet") - - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("scan an empty parquet table") { checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0)) @@ -292,10 +275,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { ) table("test_parquet_ctas").queryExecution.optimizedPlan match { - case LogicalRelation(_: ParquetRelation2) => // OK + case LogicalRelation(_: ParquetRelation) => // OK case _ => fail( "test_parquet_ctas should be converted to " + - s"${classOf[ParquetRelation2].getCanonicalName}") + s"${classOf[ParquetRelation].getCanonicalName}") } sql("DROP TABLE IF EXISTS test_parquet_ctas") @@ -316,9 +299,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation2, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(_: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[ParquetRelation].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -346,9 +329,9 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") df.queryExecution.executedPlan match { - case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation2, _, _)) => // OK + case ExecutedCommand(InsertIntoHadoopFsRelation(r: ParquetRelation, _, _)) => // OK case o => fail("test_insert_parquet should be converted to a " + - s"${classOf[ParquetRelation2].getCanonicalName} and " + + s"${classOf[ParquetRelation].getCanonicalName} and " + s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + s"However, found a ${o.toString} ") } @@ -379,17 +362,17 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { assertResult(2) { analyzed.collect { - case r @ LogicalRelation(_: ParquetRelation2) => r + case r @ LogicalRelation(_: ParquetRelation) => r }.size } sql("DROP TABLE ms_convert") } - def collectParquetRelation(df: DataFrame): ParquetRelation2 = { + def collectParquetRelation(df: DataFrame): ParquetRelation = { val plan = df.queryExecution.analyzed plan.collectFirst { - case LogicalRelation(r: ParquetRelation2) => r + case LogicalRelation(r: ParquetRelation) => r }.getOrElse { fail(s"Expecting a ParquetRelation2, but got:\n$plan") } @@ -439,7 +422,7 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { // Converted test_parquet should be cached. catalog.cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: ParquetRelation2) => // OK + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. 
" + @@ -543,81 +526,10 @@ class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase { } } -class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } - - test("MetastoreRelation in InsertIntoTable will not be converted") { - sql( - """ - |create table test_insert_parquet - |( - | intField INT - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt") - df.queryExecution.executedPlan match { - case insert: execution.InsertIntoHiveTable => // OK - case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + - s"However, found ${o.toString}.") - } - - checkAnswer( - sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"), - sql("SELECT a FROM jt WHERE jt.a > 5").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") - } - - // TODO: enable it after the fix of SPARK-5950. - ignore("MetastoreRelation in InsertIntoHiveTable will not be converted") { - sql( - """ - |create table test_insert_parquet - |( - | int_array array - |) - |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - |STORED AS - | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' - | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' - """.stripMargin) - - val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array") - df.queryExecution.executedPlan match { - case insert: execution.InsertIntoHiveTable => // OK - case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " + - s"However, found ${o.toString}.") - } - - checkAnswer( - sql("SELECT int_array FROM test_insert_parquet"), - sql("SELECT a FROM jt_array").collect() - ) - - sql("DROP TABLE IF EXISTS test_insert_parquet") - } -} - /** * A suite of tests for the Parquet support through the data sources API. 
*/ -class ParquetSourceSuiteBase extends ParquetPartitioningTest { +class ParquetSourceSuite extends ParquetPartitioningTest { override def beforeAll(): Unit = { super.beforeAll() @@ -712,20 +624,6 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest { } } } -} - -class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, true) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } test("values in arrays and maps stored in parquet are always nullable") { val df = createDataFrame(Tuple2(Map(2 -> 3), Seq(4, 5, 6)) :: Nil).toDF("m", "a") @@ -734,7 +632,7 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { val expectedSchema1 = StructType( StructField("m", mapType1, nullable = true) :: - StructField("a", arrayType1, nullable = true) :: Nil) + StructField("a", arrayType1, nullable = true) :: Nil) assert(df.schema === expectedSchema1) df.write.format("parquet").saveAsTable("alwaysNullable") @@ -772,20 +670,6 @@ class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase { } } -class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase { - val originalConf = conf.parquetUseDataSourceApi - - override def beforeAll(): Unit = { - super.beforeAll() - conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, false) - } - - override def afterAll(): Unit = { - super.afterAll() - setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf) - } -} - /** * A collection of tests for parquet data with various forms of partitioning. */ From 1efe97dc9ed31e3b8727b81be633b7e96dd3cd34 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Sun, 26 Jul 2015 18:34:19 -0700 Subject: [PATCH 074/219] [SPARK-8867][SQL] Support list / describe function usage As Hive does, we need to list all of the registered UDFs and their usage for users. We add the annotation to describe a UDF, so we can get the literal description info while registering the UDF. e.g. ```scala ExpressionDescription( usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", extended = """> SELECT _FUNC_('-1') 1""") case class Abs(child: Expression) extends UnaryArithmetic { ... 
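// Note: `_FUNC_` above is a placeholder, not literal text. When the function
// is described, the registered name is substituted for `_FUNC_`, so a single
// annotation serves the expression under every alias it is registered with.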
``` Author: Cheng Hao Closes #7259 from chenghao-intel/desc_function and squashes the following commits: cf29bba [Cheng Hao] fixing the code style issue 5193855 [Cheng Hao] Add more powerful parser for show functions c645a6b [Cheng Hao] fix bug in unit test 78d40f1 [Cheng Hao] update the padding issue for usage 48ee4b3 [Cheng Hao] update as feedback 70eb4e9 [Cheng Hao] add show/describe function support --- .../expressions/ExpressionDescription.java | 43 +++++++++++ .../catalyst/expressions/ExpressionInfo.java | 55 +++++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 56 +++++++++++--- .../sql/catalyst/expressions/arithmetic.scala | 3 + .../expressions/stringOperations.scala | 6 ++ .../sql/catalyst/plans/logical/commands.scala | 28 ++++++- .../org/apache/spark/sql/SparkSQLParser.scala | 28 ++++++- .../spark/sql/execution/SparkStrategies.scala | 5 ++ .../apache/spark/sql/execution/commands.scala | 77 ++++++++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 +++++++ .../org/apache/spark/sql/hive/hiveUDFs.scala | 28 ++++++- .../hive/execution/HiveComparisonTest.scala | 6 +- .../sql/hive/execution/SQLQuerySuite.scala | 48 +++++++++++- 13 files changed, 389 insertions(+), 20 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java new file mode 100644 index 0000000000000..9e10f27d59d55 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionDescription.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.annotation.DeveloperApi; + +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * ::DeveloperApi:: + * + * A function description type which can be recognized by FunctionRegistry, and will be used to + * show the usage of the function in human-readable form. + * + * `usage()` will be used for the function usage in a brief way. + * `extended()` will be used for the function usage in a verbose way; an example + * is expected to be provided. + * + * The function name can be referred to by `_FUNC_` in both `usage` and `extended`, as it's + * registered in `FunctionRegistry`. 
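For a concrete sense of how these two fields surface to users, the round trip below is pieced together from the `DESCRIBE FUNCTION` test expectations added later in this patch (an illustrative sketch, not text from the patch itself):

```scala
// Querying the usage recorded by the annotation on Upper:
sqlContext.sql("DESCRIBE FUNCTION EXTENDED upper").collect().foreach(println)
// [Function: upper]
// [Class: org.apache.spark.sql.catalyst.expressions.Upper]
// [Usage: upper(str) - Returns str with all characters changed to uppercase]
// [Extended Usage:
// > SELECT upper('SparkSql');
//  'SPARKSQL']
```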
+ */ +@DeveloperApi +@Retention(RetentionPolicy.RUNTIME) +public @interface ExpressionDescription { + String usage() default "_FUNC_ is undocumented"; + String extended() default "No example for _FUNC_."; +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java new file mode 100644 index 0000000000000..ba8e9cb4be28b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/ExpressionInfo.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +/** + * Expression information, which will be used to describe an expression. + */ +public class ExpressionInfo { + private String className; + private String usage; + private String name; + private String extended; + + public String getClassName() { + return className; + } + + public String getUsage() { + return usage; + } + + public String getName() { + return name; + } + + public String getExtended() { + return extended; + } + + public ExpressionInfo(String className, String name, String usage, String extended) { + this.className = className; + this.name = name; + this.usage = usage; + this.extended = extended; + } + + public ExpressionInfo(String className, String name) { + this(className, name, null, null); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9c349838c28a1..aa05f448d12bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -30,26 +30,44 @@ import org.apache.spark.sql.catalyst.util.StringKeyHashMap /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { - def registerFunction(name: String, builder: FunctionBuilder): Unit + final def registerFunction(name: String, builder: FunctionBuilder): Unit = { + registerFunction(name, new ExpressionInfo(builder.getClass.getCanonicalName, name), builder) + } + + def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit @throws[AnalysisException]("If function does not exist") def lookupFunction(name: String, children: Seq[Expression]): Expression + + /* List all of the registered function names. */ + def listFunction(): Seq[String] + + /* Get the class of the registered function by specified name. 
*/ + def lookupFunction(name: String): Option[ExpressionInfo] } class SimpleFunctionRegistry extends FunctionRegistry { - private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false) + private val functionBuilders = + StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) - override def registerFunction(name: String, builder: FunctionBuilder): Unit = { - functionBuilders.put(name, builder) + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = { + functionBuilders.put(name, (info, builder)) } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - val func = functionBuilders.get(name).getOrElse { + val func = functionBuilders.get(name).map(_._2).getOrElse { throw new AnalysisException(s"undefined function $name") } func(children) } + + override def listFunction(): Seq[String] = functionBuilders.iterator.map(_._1).toList.sorted + + override def lookupFunction(name: String): Option[ExpressionInfo] = { + functionBuilders.get(name).map(_._1) + } } /** @@ -57,13 +75,22 @@ class SimpleFunctionRegistry extends FunctionRegistry { * functions are already filled in and the analyzer needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { - override def registerFunction(name: String, builder: FunctionBuilder): Unit = { + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = { throw new UnsupportedOperationException } override def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } + + override def listFunction(): Seq[String] = { + throw new UnsupportedOperationException + } + + override def lookupFunction(name: String): Option[ExpressionInfo] = { + throw new UnsupportedOperationException + } } @@ -71,7 +98,7 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression - val expressions: Map[String, FunctionBuilder] = Map( + val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), expression[CreateArray]("array"), @@ -205,13 +232,13 @@ object FunctionRegistry { val builtin: FunctionRegistry = { val fr = new SimpleFunctionRegistry - expressions.foreach { case (name, builder) => fr.registerFunction(name, builder) } + expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) } fr } /** See usage above. 
*/ private def expression[T <: Expression](name: String) - (implicit tag: ClassTag[T]): (String, FunctionBuilder) = { + (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // See if we can find a constructor that accepts Seq[Expression] val varargCtor = Try(tag.runtimeClass.getDeclaredConstructor(classOf[Seq[_]])).toOption @@ -237,6 +264,15 @@ object FunctionRegistry { } } } - (name, builder) + + val clazz = tag.runtimeClass + val df = clazz.getAnnotation(classOf[ExpressionDescription]) + if (df != null) { + (name, + (new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()), + builder)) + } else { + (name, (new ExpressionInfo(clazz.getCanonicalName, name), builder)) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 7c254a8750a9f..b37f530ec6814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -65,6 +65,9 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects /** * A function that gets the absolute value of the numeric value. */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", + extended = "> SELECT _FUNC_('-1');\n1") case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes with CodegenFallback { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index cf187ad5a0a9f..38b0fb37dee3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -214,6 +214,9 @@ trait String2StringExpression extends ImplicitCastInputTypes { /** * A function that converts the characters of a string to uppercase. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns str with all characters changed to uppercase", + extended = "> SELECT _FUNC_('SparkSql');\n 'SPARKSQL'") case class Upper(child: Expression) extends UnaryExpression with String2StringExpression { @@ -227,6 +230,9 @@ case class Upper(child: Expression) /** * A function that converts the characters of a string to lowercase. 
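The `else` branch above is worth a note: an expression class without the annotation still registers, just with a `null` usage, and `replaceFunctionName` further down in this patch turns that into a stub. A hedged sketch of the resulting behavior (`foo` and its class are placeholders):

```scala
// Without @ExpressionDescription, lookupFunction returns an
// ExpressionInfo(className, name) whose usage is null, so DESCRIBE prints:
sqlContext.sql("DESCRIBE FUNCTION foo").collect().foreach(println)
// [Function: foo]
// [Class: com.example.Foo]
// [Usage: To be added.]
```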
*/ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", + extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 246f4d7e34d3d..e6621e0f50a9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.types.StringType /** * A logical node that represents a non-query command to be executed by the system. For example, @@ -25,3 +26,28 @@ import org.apache.spark.sql.catalyst.expressions.Attribute * eagerly executed. */ trait Command + +/** + * Returned for the "DESCRIBE [EXTENDED] FUNCTION functionName" command. + * @param functionName The function to be described. + * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + */ +private[sql] case class DescribeFunction( + functionName: String, + isExtended: Boolean) extends LogicalPlan with Command { + + override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( + AttributeReference("function_desc", StringType, nullable = false)()) +} + +/** + * Returned for the "SHOW FUNCTIONS" command, which lists all of the + * registered functions. 
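Both logical nodes declare a fixed, single-column, non-nullable string schema, so the command output can be consumed like any other relation. A hedged usage sketch:

```scala
// Each result row carries one name in the `function` column declared above.
val names = sqlContext.sql("SHOW FUNCTIONS").collect().map(_.getString(0))
assert(names.contains("upper"))  // illustrative check only
```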
+ */ +private[sql] case class ShowFunctions( + db: Option[String], pattern: Option[String]) extends LogicalPlan with Command { + override def children: Seq[LogicalPlan] = Seq.empty + override val output: Seq[Attribute] = Seq( + AttributeReference("function", StringType, nullable = false)()) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala index e59fa6e162900..ea8fce6ca9cf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala @@ -21,7 +21,7 @@ import scala.util.parsing.combinator.RegexParsers import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{DescribeFunction, LogicalPlan, ShowFunctions} import org.apache.spark.sql.execution._ import org.apache.spark.sql.types.StringType @@ -57,6 +57,10 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val AS = Keyword("AS") protected val CACHE = Keyword("CACHE") protected val CLEAR = Keyword("CLEAR") + protected val DESCRIBE = Keyword("DESCRIBE") + protected val EXTENDED = Keyword("EXTENDED") + protected val FUNCTION = Keyword("FUNCTION") + protected val FUNCTIONS = Keyword("FUNCTIONS") protected val IN = Keyword("IN") protected val LAZY = Keyword("LAZY") protected val SET = Keyword("SET") @@ -65,7 +69,8 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr protected val TABLES = Keyword("TABLES") protected val UNCACHE = Keyword("UNCACHE") - override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others + override protected lazy val start: Parser[LogicalPlan] = + cache | uncache | set | show | desc | others private lazy val cache: Parser[LogicalPlan] = CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ { @@ -85,9 +90,24 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr case input => SetCommandParser(input) } + // It can be the following patterns: + // SHOW FUNCTIONS; + // SHOW FUNCTIONS mydb.func1; + // SHOW FUNCTIONS func1; + // SHOW FUNCTIONS `mydb.a`.`func1.aa`; private lazy val show: Parser[LogicalPlan] = - SHOW ~> TABLES ~ (IN ~> ident).? ^^ { - case _ ~ dbName => ShowTablesCommand(dbName) + ( SHOW ~> TABLES ~ (IN ~> ident).? ^^ { + case _ ~ dbName => ShowTablesCommand(dbName) + } + | SHOW ~ FUNCTIONS ~> ((ident <~ ".").? ~ (ident | stringLit)).? ^^ { + case Some(f) => ShowFunctions(f._1, Some(f._2)) + case None => ShowFunctions(None, None) + } + ) + + private lazy val desc: Parser[LogicalPlan] = + DESCRIBE ~ FUNCTION ~> EXTENDED.? 
~ (ident | stringLit) ^^ { + case isExtended ~ functionName => DescribeFunction(functionName, isExtended.isDefined) } private lazy val others: Parser[LogicalPlan] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e2c7e8006f3b1..deeea3900c241 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -428,6 +428,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { ExecutedCommand( RunnableDescribeCommand(resultPlan, describe.output, isExtended)) :: Nil + case logical.ShowFunctions(db, pattern) => ExecutedCommand(ShowFunctions(db, pattern)) :: Nil + + case logical.DescribeFunction(function, extended) => + ExecutedCommand(DescribeFunction(function, extended)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index bace3f8a9c8d4..6b83025d5a153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, Expression, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -298,3 +298,78 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma rows } } + +/** + * A command for users to list all of the registered functions. + * The syntax of using this command in SQL is: + * {{{ + * SHOW FUNCTIONS + * }}} + * TODO: currently we simply ignore the db + */ +case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("function", StringType, nullable = false) :: Nil) + + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = pattern match { + case Some(p) => + try { + val regex = java.util.regex.Pattern.compile(p) + sqlContext.functionRegistry.listFunction().filter(regex.matcher(_).matches()).map(Row(_)) + } catch { + // the user-provided pattern may be an invalid regex; in that case return no rows. + case _: Throwable => Seq.empty[Row] + } + case None => + sqlContext.functionRegistry.listFunction().map(Row(_)) + } +} + +/** + * A command for users to get the usage of a registered function. + * The syntax of using this command in SQL is + * {{{ + * DESCRIBE FUNCTION [EXTENDED] upper; + * }}} + */ +case class DescribeFunction( + functionName: String, + isExtended: Boolean) extends RunnableCommand { + + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("function_desc", StringType, nullable = false) :: Nil) + + schema.toAttributes + } + + private def replaceFunctionName(usage: String, functionName: String): String = { + if (usage == null) { + "To be added." 
+ } else { + usage.replaceAll("_FUNC_", functionName) + } + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.functionRegistry.lookupFunction(functionName) match { + case Some(info) => + val result = + Row(s"Function: ${info.getName}") :: + Row(s"Class: ${info.getClassName}") :: + Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + + if (isExtended) { + result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + } else { + result + } + + case None => Seq(Row(s"Function: $functionName is not found.")) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cd386b7a3ecf9..8cef0b39f87dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.scalatest.BeforeAndAfterAll import java.sql.Timestamp @@ -58,6 +59,31 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { checkAnswer(queryCoalesce, Row("1") :: Nil) } + test("show functions") { + checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_))) + } + + test("describe functions") { + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql');", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 54bf6bd67ff84..8732e9abf8d31 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -76,8 +76,32 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) } } - override def registerFunction(name: String, builder: FunctionBuilder): Unit = - underlying.registerFunction(name, builder) + override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) + : Unit = underlying.registerFunction(name, info, builder) + + /* List all of the registered function names. */ + override def listFunction(): Seq[String] = { + val a = FunctionRegistry.getFunctionNames ++ underlying.listFunction() + a.toList.sorted + } + + /* Get the class of the registered function by specified name. 
*/ + override def lookupFunction(name: String): Option[ExpressionInfo] = { + underlying.lookupFunction(name).orElse( + Try { + val info = FunctionRegistry.getFunctionInfo(name) + val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) + if (annotation != null) { + Some(new ExpressionInfo( + info.getFunctionClass.getCanonicalName, + annotation.name(), + annotation.value(), + annotation.extended())) + } else { + None + } + }.getOrElse(None)) + } } private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index efb04bf3d5097..638b9c810372a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -370,7 +370,11 @@ abstract class HiveComparisonTest // Check that the results match unless it's an EXPLAIN query. val preparedHive = prepareAnswer(hiveQuery, hive) - if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && preparedHive != catalyst) { + // We will ignore the ExplainCommand, ShowFunctions, DescribeFunction + if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && + (!hiveQuery.logical.isInstanceOf[ShowFunctions]) && + (!hiveQuery.logical.isInstanceOf[DescribeFunction]) && + preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive val catalystPrintOut = s"== CATALYST - ${catalyst.size} row(s) ==" +: catalyst diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ff42fdefaa62a..013936377b24c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} +import scala.collection.JavaConversions._ + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.DefaultParserDialect -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHive @@ -138,6 +140,50 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { (1 to 6).map(_ => Row("CA", 20151))) } + test("show functions") { + val allFunctions = + (FunctionRegistry.builtin.listFunction().toSet[String] ++ + org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames).toList.sorted + checkAnswer(sql("SHOW functions"), allFunctions.map(Row(_))) + checkAnswer(sql("SHOW functions abs"), Row("abs")) + checkAnswer(sql("SHOW functions 'abs'"), Row("abs")) + checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) + checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `a function doesn't exist`"), Nil) + checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + // this will probably fail if we add more functions with the `sha` prefix. 
+ checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + } + + test("describe functions") { + // The Spark SQL built-in functions + checkExistence(sql("describe function extended upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase", + "Extended Usage:", + "> SELECT upper('SparkSql')", + "'SPARKSQL'") + + checkExistence(sql("describe functioN Upper"), true, + "Function: upper", + "Class: org.apache.spark.sql.catalyst.expressions.Upper", + "Usage: upper(str) - Returns str with all characters changed to uppercase") + + checkExistence(sql("describe functioN Upper"), false, + "Extended Usage") + + checkExistence(sql("describe functioN abcadf"), true, + "Function: abcadf is not found.") + + checkExistence(sql("describe functioN `~`"), true, + "Function: ~", + "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", + "Usage: ~ n - Bitwise not") + } + test("SPARK-5371: union with null and sum") { val df = Seq((1, 1)).toDF("c1", "c2") df.registerTempTable("table1") From 945d8bcbf67032edd7bdd201cf9f88c75b3464f7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 26 Jul 2015 22:13:37 -0700 Subject: [PATCH 075/219] [SPARK-9306] [SQL] Don't use SortMergeJoin when joining on unsortable columns JIRA: https://issues.apache.org/jira/browse/SPARK-9306 Author: Liang-Chi Hsieh Closes #7645 from viirya/smj_unsortable and squashes the following commits: a240707 [Liang-Chi Hsieh] Use forall instead of exists for readability. 55221fa [Liang-Chi Hsieh] Shouldn't use SortMergeJoin when joining on unsortable columns. --- .../sql/catalyst/planning/patterns.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 19 +++++++++++++++---- .../org/apache/spark/sql/JoinSuite.scala | 12 ++++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b8e3b0d53a505..1e7b2a536ac12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -184,7 +184,7 @@ object PartialAggregation { * A pattern that finds joins with equality conditions that can be evaluated using equi-join. 
*/ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, rightKeys, leftKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ type ReturnType = (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index deeea3900c241..306bbfec624c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -35,9 +35,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + case ExtractEquiJoinKeys( + LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastLeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys @@ -90,6 +89,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } + private[this] def isValidSort( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Boolean = { + leftKeys.zip(rightKeys).forall { keys => + (keys._1.dataType, keys._2.dataType) match { + case (l: AtomicType, r: AtomicType) => true + case (NullType, NullType) => true + case _ => false + } + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) @@ -100,7 +111,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled => + if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) => val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8953889d1fae9..dfb2a7e099748 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -108,6 +108,18 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } } + test("SortMergeJoin shouldn't work on unsortable columns") { + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + 
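// Restoring the flag in `finally` keeps a failed assertJoin from leaking
// SORTMERGE_JOIN=true into the tests that follow.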
ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + } + test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") From aa80c64fcf9626b3720ee000a653db9266b74839 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 23:01:04 -0700 Subject: [PATCH 076/219] [SPARK-9368][SQL] Support get(ordinal, dataType) generic getter in UnsafeRow. Author: Reynold Xin Closes #7682 from rxin/unsaferow-generic-getter and squashes the following commits: 3063788 [Reynold Xin] Reset the change for real this time. 0f57c55 [Reynold Xin] Reset the changes in ExpressionEvalHelper. fb6ca30 [Reynold Xin] Support BinaryType. 24a3e46 [Reynold Xin] Added support for DateType/TimestampType. 9989064 [Reynold Xin] JoinedRow. 11f80a3 [Reynold Xin] [SPARK-9368][SQL] Support get(ordinal, dataType) generic getter in UnsafeRow. --- .../sql/catalyst/expressions/UnsafeRow.java | 52 ++++++++++++++++++- .../spark/sql/catalyst/InternalRow.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 2 +- .../expressions/SpecificMutableRow.scala | 2 +- .../codegen/GenerateProjection.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 4 +- 6 files changed, 58 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 87e5a89c19658..0fb33dd5a15a0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -24,7 +24,7 @@ import java.util.HashSet; import java.util.Set; -import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -235,6 +235,41 @@ public Object get(int ordinal) { throw new UnsupportedOperationException(); } + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof NullType) { + return null; + } else if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof DecimalType) { + return getDecimal(ordinal); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType) dataType).size()); + } else { + throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); + } + } + @Override public boolean isNullAt(int ordinal) { assertIndexIsValid(ordinal); @@ -436,4 +471,19 @@ public String toString() { public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } + + /** + * 
Writes the content of this row into a memory address, identified by an object and an offset. + * The target memory address must already have been allocated, and have enough space to hold all + * the bytes in this row. + */ + public void writeToMemory(Object target, long targetOffset) { + PlatformDependent.copyMemory( + baseObject, + baseOffset, + target, + targetOffset, + sizeInBytes + ); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 385d9671386dc..ad3977281d1a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -30,11 +30,11 @@ abstract class InternalRow extends Serializable { def numFields: Int - def get(ordinal: Int): Any + def get(ordinal: Int): Any = get(ordinal, null) def genericGet(ordinal: Int): Any = get(ordinal, null) - def get(ordinal: Int, dataType: DataType): Any = get(ordinal) + def get(ordinal: Int, dataType: DataType): Any def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index cc89d74146b34..27d6ff587ab71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -198,7 +198,7 @@ class JoinedRow extends InternalRow { if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } - override def get(i: Int): Any = + override def get(i: Int, dataType: DataType): Any = if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 5953a093dc684..b877ce47c083f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -219,7 +219,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).isNull = true } - override def get(i: Int): Any = values(i).boxed + override def get(i: Int, dataType: DataType): Any = values(i).boxed override def getStruct(ordinal: Int, numFields: Int): InternalRow = { values(ordinal).boxed.asInstanceOf[InternalRow] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index a361b216eb472..35920147105ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -183,7 +183,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { + public Object get(int i, ${classOf[DataType].getName} dataType) { 
if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index daeabe8e90f1d..b7c4ece4a16fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -99,7 +99,7 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal override def numFields: Int = values.length - override def get(i: Int): Any = values(i) + override def get(i: Int, dataType: DataType): Any = values(i) override def getStruct(ordinal: Int, numFields: Int): InternalRow = { values(ordinal).asInstanceOf[InternalRow] @@ -130,7 +130,7 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow { override def numFields: Int = values.length - override def get(i: Int): Any = values(i) + override def get(i: Int, dataType: DataType): Any = values(i) override def getStruct(ordinal: Int, numFields: Int): InternalRow = { values(ordinal).asInstanceOf[InternalRow] From 4ffd3a1db5ecff653b02aa325786e734351c8bd2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 26 Jul 2015 23:58:03 -0700 Subject: [PATCH 077/219] [SPARK-9371][SQL] fix the support for special chars in column names for hive context Author: Wenchen Fan Closes #7684 from cloud-fan/hive and squashes the following commits: da21ffe [Wenchen Fan] fix the support for special chars in column names for hive context --- .../src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 6 +++--- .../apache/spark/sql/hive/execution/SQLQuerySuite.scala | 8 ++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 620b8a44d8a9b..2f79b0aad045c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1321,11 +1321,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Attribute References */ case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute(cleanIdentifier(name)) + UnresolvedAttribute.quoted(cleanIdentifier(name)) case Token(".", qualifier :: Token(attr, Nil) :: Nil) => nodeToExpr(qualifier) match { - case UnresolvedAttribute(qualifierName) => - UnresolvedAttribute(qualifierName :+ cleanIdentifier(attr)) + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) case other => UnresolvedExtractValue(other, Literal(attr)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 013936377b24c..8371dd0716c06 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1067,4 +1067,12 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { ) TestHive.dropTempTable("test_SPARK8588") } + + test("SPARK-9371: fix the support for special chars in column names for hive context") { + TestHive.read.json(TestHive.sparkContext.makeRDD( + """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) + .registerTempTable("t") + + checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, 
`q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) + } } From 72981bc8f0d421e2563e2543a8c16a8cc76ad3aa Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 27 Jul 2015 17:15:35 +0800 Subject: [PATCH 078/219] [SPARK-7943] [SPARK-8105] [SPARK-8435] [SPARK-8714] [SPARK-8561] Fixes multi-database support This PR fixes a set of issues related to multi-database. A new data structure `TableIdentifier` is introduced to identify a table among multiple databases. We should stop using a single `String` (table name without database name), or `Seq[String]` (optional database name plus table name) to identify tables internally. Author: Cheng Lian Closes #7623 from liancheng/spark-8131-multi-db and squashes the following commits: f3bcd4b [Cheng Lian] Addresses PR comments e0eb76a [Cheng Lian] Fixes styling issues 41e2207 [Cheng Lian] Fixes multi-database support d4d1ec2 [Cheng Lian] Adds multi-database test cases --- .../apache/spark/sql/catalyst/SqlParser.scala | 14 ++ .../spark/sql/catalyst/TableIdentifier.scala | 31 ++++ .../spark/sql/catalyst/analysis/Catalog.scala | 9 +- .../apache/spark/sql/DataFrameWriter.scala | 83 ++++----- .../org/apache/spark/sql/SQLContext.scala | 6 +- .../spark/sql/execution/datasources/ddl.scala | 15 +- .../spark/sql/parquet/ParquetTest.scala | 4 +- .../apache/spark/sql/test/SQLTestUtils.scala | 29 +++- .../apache/spark/sql/hive/HiveContext.scala | 5 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 31 +++- .../spark/sql/hive/MultiDatabaseSuite.scala | 159 ++++++++++++++++++ .../apache/spark/sql/hive/orc/OrcTest.scala | 7 +- 12 files changed, 327 insertions(+), 66 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index c494e5d704213..b423f0fa04f69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -48,6 +48,15 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + def parseTableIdentifier(input: String): TableIdentifier = { + // Initialize the Keywords. + initLexical + phrase(tableIdentifier)(new lexical.Scanner(input)) match { + case Success(ident, _) => ident + case failureOrError => sys.error(failureOrError.toString) + } + } + // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword` // properties via reflection the class in runtime for constructing the SqlLexical object protected val ALL = Keyword("ALL") @@ -444,4 +453,9 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ { case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest) } + + protected lazy val tableIdentifier: Parser[TableIdentifier] = + (ident <~ ".").? 
~ ident ^^ { + case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala new file mode 100644 index 0000000000000..aebcdeb9d070f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +/** + * Identifies a `table` in `database`. If `database` is not defined, the current database is used. + */ +private[sql] case class TableIdentifier(table: String, database: Option[String] = None) { + def withDatabase(database: String): TableIdentifier = this.copy(database = Some(database)) + + def toSeq: Seq[String] = database.toSeq :+ table + + override def toString: String = toSeq.map("`" + _ + "`").mkString(".") + + def unquotedString: String = toSeq.mkString(".") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 1541491608b24..5766e6a2dd51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -23,8 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.EmptyConf +import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** @@ -54,7 +53,7 @@ trait Catalog { */ def getTables(databaseName: Option[String]): Seq[(String, Boolean)] - def refreshTable(databaseName: String, tableName: String): Unit + def refreshTable(tableIdent: TableIdentifier): Unit def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit @@ -132,7 +131,7 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { result } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } } @@ -241,7 +240,7 @@ object EmptyCatalog extends Catalog { override def unregisterAllTables(): Unit = {} - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { throw new UnsupportedOperationException } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 05da05d7b8050..7e3318cefe62c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.Properties import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} @@ -159,15 +160,19 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { - val partitions = - partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) - val overwrite = (mode == SaveMode.Overwrite) - df.sqlContext.executePlan(InsertIntoTable( - UnresolvedRelation(Seq(tableName)), - partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, - overwrite, - ifNotExists = false)).toRdd + insertInto(new SqlParser().parseTableIdentifier(tableName)) + } + + private def insertInto(tableIdent: TableIdentifier): Unit = { + val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val overwrite = mode == SaveMode.Overwrite + df.sqlContext.executePlan( + InsertIntoTable( + UnresolvedRelation(tableIdent.toSeq), + partitions.getOrElse(Map.empty[String, Option[String]]), + df.logicalPlan, + overwrite, + ifNotExists = false)).toRdd } /** @@ -183,35 +188,37 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def saveAsTable(tableName: String): Unit = { - if (df.sqlContext.catalog.tableExists(tableName :: Nil) && mode != SaveMode.Overwrite) { - mode match { - case SaveMode.Ignore => - // Do nothing - - case SaveMode.ErrorIfExists => - throw new AnalysisException(s"Table $tableName already exists.") - - case SaveMode.Append => - // If it is Append, we just ask insertInto to handle it. We will not use insertInto - // to handle saveAsTable with Overwrite because saveAsTable can change the schema of - // the table. But, insertInto with Overwrite requires the schema of data be the same - // the schema of the table. - insertInto(tableName) - - case SaveMode.Overwrite => - throw new UnsupportedOperationException("overwrite mode unsupported.") - } - } else { - val cmd = - CreateTableUsingAsSelect( - tableName, - source, - temporary = false, - partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), - mode, - extraOptions.toMap, - df.logicalPlan) - df.sqlContext.executePlan(cmd).toRdd + saveAsTable(new SqlParser().parseTableIdentifier(tableName)) + } + + private def saveAsTable(tableIdent: TableIdentifier): Unit = { + val tableExists = df.sqlContext.catalog.tableExists(tableIdent.toSeq) + + (tableExists, mode) match { + case (true, SaveMode.Ignore) => + // Do nothing + + case (true, SaveMode.ErrorIfExists) => + throw new AnalysisException(s"Table $tableIdent already exists.") + + case (true, SaveMode.Append) => + // If it is Append, we just ask insertInto to handle it. We will not use insertInto + // to handle saveAsTable with Overwrite because saveAsTable can change the schema of + // the table. But, insertInto with Overwrite requires the schema of the data to be + // the same as the schema of the table. 
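The rewritten `(tableExists, mode)` match reads like a small decision table: `Ignore` on an existing table is a no-op, `ErrorIfExists` throws, `Append` delegates to `insertInto` (next line), and every remaining case, including a nonexistent table, falls through to `CreateTableUsingAsSelect`. A hedged sketch of the database-qualified calls this change enables (`db.tbl` is a placeholder identifier):

```scala
// Parsed via SqlParser.parseTableIdentifier rather than treated as one name.
df.write.mode(SaveMode.Append).saveAsTable("db.tbl")
df.write.mode(SaveMode.Ignore).saveAsTable("db.tbl") // no-op if db.tbl exists
```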
+ insertInto(tableIdent) + + case _ => + val cmd = + CreateTableUsingAsSelect( + tableIdent.unquotedString, + source, + temporary = false, + partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + mode, + extraOptions.toMap, + df.logicalPlan) + df.sqlContext.executePlan(cmd).toRdd } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0e25e06e99ab2..dbb2a09846548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -798,8 +798,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group ddl_ops * @since 1.3.0 */ - def table(tableName: String): DataFrame = - DataFrame(this, catalog.lookupRelation(Seq(tableName))) + def table(tableName: String): DataFrame = { + val tableIdent = new SqlParser().parseTableIdentifier(tableName) + DataFrame(this, catalog.lookupRelation(tableIdent.toSeq)) + } /** * Returns a [[DataFrame]] containing names of existing tables in the current database. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 1f2797ec5527a..e73b3704d4dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -21,16 +21,17 @@ import scala.language.{existentials, implicitConversions} import scala.util.matching.Regex import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} -import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SQLContext, SaveMode} import org.apache.spark.util.Utils /** @@ -151,7 +152,7 @@ private[sql] class DDLParser( protected lazy val refreshTable: Parser[LogicalPlan] = REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ { case maybeDatabaseName ~ tableName => - RefreshTable(maybeDatabaseName.getOrElse("default"), tableName) + RefreshTable(TableIdentifier(tableName, maybeDatabaseName)) } protected lazy val options: Parser[Map[String, String]] = @@ -442,16 +443,16 @@ private[sql] case class CreateTempTableUsingAsSelect( } } -private[sql] case class RefreshTable(databaseName: String, tableName: String) +private[sql] case class RefreshTable(tableIdent: TableIdentifier) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { // Refresh the given table's metadata first. - sqlContext.catalog.refreshTable(databaseName, tableName) + sqlContext.catalog.refreshTable(tableIdent) // If this table is cached as a InMemoryColumnarRelation, drop the original // cached version and make the new version cached lazily. 
- val logicalPlan = sqlContext.catalog.lookupRelation(Seq(databaseName, tableName)) + val logicalPlan = sqlContext.catalog.lookupRelation(tableIdent.toSeq) // Use lookupCachedData directly since RefreshTable also takes databaseName. val isCached = sqlContext.cacheManager.lookupCachedData(logicalPlan).nonEmpty if (isCached) { @@ -461,7 +462,7 @@ private[sql] case class RefreshTable(databaseName: String, tableName: String) // Uncache the logicalPlan. sqlContext.cacheManager.tryUncacheQuery(df, blocking = true) // Cache it again. - sqlContext.cacheManager.cacheQuery(df, Some(tableName)) + sqlContext.cacheManager.cacheQuery(df, Some(tableIdent.table)) } Seq.empty[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala index eb15a1609f1d0..64e94056f209a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -22,6 +22,7 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.{DataFrame, SaveMode} @@ -32,8 +33,7 @@ import org.apache.spark.sql.{DataFrame, SaveMode} * convenient to use tuples rather than special case classes when writing test cases/suites. * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. */ -private[sql] trait ParquetTest extends SQLTestUtils { - +private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => /** * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index fa01823e9417c..4c11acdab9ec0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -18,13 +18,15 @@ package org.apache.spark.sql.test import java.io.File +import java.util.UUID import scala.util.Try +import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils -trait SQLTestUtils { +trait SQLTestUtils { this: SparkFunSuite => def sqlContext: SQLContext protected def configuration = sqlContext.sparkContext.hadoopConfiguration @@ -87,4 +89,29 @@ trait SQLTestUtils { } } } + + /** + * Creates a temporary database and switches current database to it before executing `f`. This + * database is dropped after `f` returns. + */ + protected def withTempDatabase(f: String => Unit): Unit = { + val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" + + try { + sqlContext.sql(s"CREATE DATABASE $dbName") + } catch { case cause: Throwable => + fail("Failed to create temporary database", cause) + } + + try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + } + + /** + * Activates database `db` before executing `f`, then switches back to `default` database after + * `f` returns. 
+ */ + protected def activateDatabase(db: String)(f: => Unit): Unit = { + sqlContext.sql(s"USE $db") + try f finally sqlContext.sql(s"USE default") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 1b8edefef4093..110f51a305861 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -40,7 +40,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql.catalyst.ParserDialect +import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} @@ -267,7 +267,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { * @since 1.3.0 */ def refreshTable(tableName: String): Unit = { - catalog.refreshTable(catalog.client.currentDatabase, tableName) + val tableIdent = TableIdentifier(tableName).withDatabase(catalog.client.currentDatabase) + catalog.refreshTable(tableIdent) } protected[hive] def invalidateTable(tableName: String): Unit = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 262923531216f..9c707a7a2eca1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -29,13 +29,13 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} import org.apache.spark.sql.execution.datasources import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} import org.apache.spark.sql.hive.client._ @@ -43,7 +43,6 @@ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} - private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: HiveContext) extends Catalog with Logging { @@ -115,7 +114,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader) } - override def refreshTable(databaseName: String, tableName: String): Unit = { + override def refreshTable(tableIdent: TableIdentifier): Unit = { // refreshTable does not eagerly reload the cache. It just invalidate the cache. // Next time when we use the table, it will be populated in the cache. 
// Since we also cache ParquetRelations converted from Hive Parquet tables and @@ -124,7 +123,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // it is better at here to invalidate the cache to avoid confusing waring logs from the // cache loader (e.g. cannot find data source provider, which is only defined for // data source table.). - invalidateTable(databaseName, tableName) + invalidateTable(tableIdent.database.getOrElse(client.currentDatabase), tableIdent.table) } def invalidateTable(databaseName: String, tableName: String): Unit = { @@ -144,7 +143,27 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive provider: String, options: Map[String, String], isExternal: Boolean): Unit = { - val (dbName, tblName) = processDatabaseAndTableName(client.currentDatabase, tableName) + createDataSourceTable( + new SqlParser().parseTableIdentifier(tableName), + userSpecifiedSchema, + partitionColumns, + provider, + options, + isExternal) + } + + private def createDataSourceTable( + tableIdent: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { + val (dbName, tblName) = { + val database = tableIdent.database.getOrElse(client.currentDatabase) + processDatabaseAndTableName(database, tableIdent.table) + } + val tableProperties = new scala.collection.mutable.HashMap[String, String] tableProperties.put("spark.sql.sources.provider", provider) @@ -177,7 +196,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive // partitions when we load the table. However, if there are specified partition columns, // we simplily ignore them and provide a warning message.. logWarning( - s"The schema and partitions of table $tableName will be inferred when it is loaded. " + + s"The schema and partitions of table $tableIdent will be inferred when it is loaded. " + s"Specified partition columns (${partitionColumns.mkString(",")}) will be ignored.") } Seq.empty[HiveColumn] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala new file mode 100644 index 0000000000000..73852f13ad20d --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode} + +class MultiDatabaseSuite extends QueryTest with SQLTestUtils { + override val sqlContext: SQLContext = TestHive + + import sqlContext.sql + + private val df = sqlContext.range(10).coalesce(1) + + test(s"saveAsTable() to non-default database - with USE - Overwrite") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + } + } + + test(s"saveAsTable() to non-default database - without USE - Overwrite") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df) + } + } + + test(s"saveAsTable() to non-default database - with USE - Append") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + df.write.mode(SaveMode.Append).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + checkAnswer(sqlContext.table("t"), df.unionAll(df)) + } + + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test(s"saveAsTable() to non-default database - without USE - Append") { + withTempDatabase { db => + df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") + df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") + assert(sqlContext.tableNames(db).contains("t")) + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test(s"insertInto() non-default database - with USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + } + + test(s"insertInto() non-default database - without USE") { + withTempDatabase { db => + activateDatabase(db) { + df.write.mode(SaveMode.Overwrite).saveAsTable("t") + assert(sqlContext.tableNames().contains("t")) + } + + assert(sqlContext.tableNames(db).contains("t")) + + df.write.insertInto(s"$db.t") + checkAnswer(sqlContext.table(s"$db.t"), df.unionAll(df)) + } + } + + test("Looks up tables in non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql("CREATE TABLE t (key INT)") + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + } + + checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + } + } + + test("Drops a table in a non-default database") { + withTempDatabase { db => + activateDatabase(db) { + sql(s"CREATE TABLE t (key INT)") + assert(sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(sqlContext.tableNames(db).contains("t")) + + activateDatabase(db) { + sql(s"DROP TABLE t") + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames("default").contains("t")) + } + + assert(!sqlContext.tableNames().contains("t")) + assert(!sqlContext.tableNames(db).contains("t")) + } + } + + test("Refreshes a table in a non-default database") { + import org.apache.spark.sql.functions.lit + + withTempDatabase 
{ db => + withTempPath { dir => + val path = dir.getCanonicalPath + + activateDatabase(db) { + sql( + s"""CREATE EXTERNAL TABLE t (id BIGINT) + |PARTITIONED BY (p INT) + |STORED AS PARQUET + |LOCATION '$path' + """.stripMargin) + + checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + + df.write.parquet(s"$path/p=1") + sql("ALTER TABLE t ADD PARTITION (p=1)") + sql("REFRESH TABLE t") + checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + } + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 9d76d6503a3e6..145965388da01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -22,14 +22,15 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ +import org.apache.spark.sql.test.SQLTestUtils -private[sql] trait OrcTest extends SQLTestUtils { +private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite => lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive - import sqlContext.sparkContext import sqlContext.implicits._ + import sqlContext.sparkContext /** * Writes `data` to a Orc file, which is then passed to `f` and will be deleted after `f` From 622838165756e9669cbf7af13eccbc719638f40b Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Mon, 27 Jul 2015 08:02:40 -0500 Subject: [PATCH 079/219] [SPARK-8405] [DOC] Add how to view logs on Web UI when yarn log aggregation is enabled Some users may not be aware that the logs are available on Web UI even if Yarn log aggregation is enabled. Update the doc to make this clear and what need to be configured. Author: Carson Wang Closes #7463 from carsonwang/YarnLogDoc and squashes the following commits: 274c054 [Carson Wang] Minor text fix 74df3a1 [Carson Wang] address comments 5a95046 [Carson Wang] Update the text in the doc e5775c1 [Carson Wang] Update doc about how to view the logs on Web UI when yarn log aggregation is enabled --- docs/running-on-yarn.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index de22ab557cacf..cac08a91b97d9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -68,9 +68,9 @@ In YARN terminology, executors and application masters run inside "containers". yarn logs -applicationId -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). The logs are also available on the Spark Web UI under the Executors Tab. You need to have both the Spark history server and the MapReduce history server running and configure `yarn.log.server.url` in `yarn-site.xml` properly. 
The log URL on the Spark history server UI will redirect you to the MapReduce history server to show the aggregated logs. -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. The logs are also available on the Spark Web UI under the Executors Tab and do not require running the MapReduce history server. To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` From aa19c696e25ebb07fd3df110cfcbcc69954ce335 Mon Sep 17 00:00:00 2001 From: Rene Treffer Date: Mon, 27 Jul 2015 23:29:40 +0800 Subject: [PATCH 080/219] [SPARK-4176] [SQL] Supports decimal types with precision > 18 in Parquet This PR is based on #6796 authored by rtreffer. To support large decimal precisions (> 18), we do the following things in this PR:
1. Making `CatalystSchemaConverter` support large decimal precision: Decimal types with large precision are always converted to fixed-length byte arrays.
2. Making `CatalystRowConverter` support reading decimal values with large precision: when the precision is > 18, it constructs `Decimal` values with an unscaled `BigInteger` rather than an unscaled `Long`.
3. Making `RowWriteSupport` support writing decimal values with large precision: in this PR we always write decimals as fixed-length byte arrays, because the Parquet write path hasn't been refactored to conform to the Parquet format spec (see SPARK-6774 & SPARK-8848).
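As a rough illustration of the new behavior (not part of this patch; this sketch assumes a running `sqlContext`, and the output path is made up), a decimal column with precision greater than 18 can now round-trip through Parquet:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    val schema = StructType(StructField("dec", DecimalType(38, 18)) :: Nil)
    val rows = sqlContext.sparkContext.parallelize(
      Seq(Row(BigDecimal("12345678901234567890.123456789012345678"))))
    // Precision 38 > 18: written as a FIXED_LEN_BYTE_ARRAY annotated by DECIMAL.
    sqlContext.createDataFrame(rows, schema).write.parquet("/tmp/decimals")
    // The read path reconstructs the value from an unscaled BigInteger.
    sqlContext.read.parquet("/tmp/decimals").collect()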
Two follow-up tasks should be done in future PRs:
- [ ] Writing decimals as `INT32`, `INT64` when possible while fixing SPARK-8848
- [ ] Adding compatibility tests as part of SPARK-5463
Author: Cheng Lian Closes #7455 from liancheng/spark-4176 and squashes the following commits: a543d10 [Cheng Lian] Fixes errors introduced while rebasing 9e31cdf [Cheng Lian] Supports decimals with precision > 18 for Parquet --- .../sql/parquet/CatalystRowConverter.scala | 25 +++++--- .../sql/parquet/CatalystSchemaConverter.scala | 46 +++++++------ .../sql/parquet/ParquetTableSupport.scala | 64 +++++++++++++------ .../spark/sql/parquet/ParquetIOSuite.scala | 10 +-- 4 files changed, 85 insertions(+), 60 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index b5e4263008f56..e00bd90edb3dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.math.{BigDecimal, BigInteger} import java.nio.ByteOrder import scala.collection.JavaConversions._ @@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter( val scale = decimalType.scale val bytes = value.getBytes - var unscaled = 0L - var i = 0 + if (precision <= 8) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + var unscaled = 0L + var i = 0 - while (i < bytes.length) { - unscaled = (unscaled << 8) | (bytes(i) & 0xff) - i += 1 - } + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xff) + i += 1 + } - val bits = 8 * bytes.length - unscaled = (unscaled << (64 - bits)) >> (64 - bits) - Decimal(unscaled, precision, scale) + val bits = 8 * bytes.length + unscaled = (unscaled << (64 - bits)) >> (64 - bits) + Decimal(unscaled, precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. + Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
" + - s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," + - "decimal precision and scale must be specified, " + - "and precision must be less than or equal to 18.") - // ===================================== // Decimals (follow Parquet format spec) // ===================================== @@ -436,7 +430,7 @@ private[parquet] class CatalystSchemaConverter( .as(DECIMAL) .precision(precision) .scale(scale) - .length(minBytesForPrecision(precision)) + .length(CatalystSchemaConverter.minBytesForPrecision(precision)) .named(field.name) // =================================================== @@ -548,15 +542,6 @@ private[parquet] class CatalystSchemaConverter( Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes .asInstanceOf[Int] } - - // Min byte counts needed to store decimals with various precisions - private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => - var numBytes = 1 - while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { - numBytes += 1 - } - numBytes - } } @@ -580,4 +565,23 @@ private[parquet] object CatalystSchemaConverter { throw new AnalysisException(message) } } + + private def computeMinBytesForPrecision(precision : Int) : Int = { + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } + + private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision) + + // Returns the minimum number of bytes needed to store a decimal with a given `precision`. + def minBytesForPrecision(precision : Int) : Int = { + if (precision < MIN_BYTES_FOR_PRECISION.length) { + MIN_BYTES_FOR_PRECISION(precision) + } else { + computeMinBytesForPrecision(precision) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index fc9f61a636768..78ecfad1d57c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.parquet +import java.math.BigInteger import java.nio.{ByteBuffer, ByteOrder} import java.util.{HashMap => JHashMap} @@ -114,11 +115,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case d: DecimalType => - if (d.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(value.asInstanceOf[Decimal], d.precision) + case DecimalType.Fixed(precision, _) => + writeDecimal(value.asInstanceOf[Decimal], precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -199,20 +197,47 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo writer.endGroup() } - // Scratch array used to write decimals as fixed-length binary - private[this] val scratchBytes = new Array[Byte](8) + // Scratch array used to write decimals as fixed-length byte array + private[this] var reusableDecimalBytes = new Array[Byte](16) private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { - val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) - val unscaledLong = decimal.toUnscaledLong - var i = 0 - var shift = 8 * (numBytes - 1) - while (i < numBytes) { - 
scratchBytes(i) = (unscaledLong >> shift).toByte - i += 1 - shift -= 8 + val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision) + + def longToBinary(unscaled: Long): Binary = { + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + reusableDecimalBytes(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) } - writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + + def bigIntegerToBinary(unscaled: BigInteger): Binary = { + unscaled.toByteArray match { + case bytes if bytes.length == numBytes => + Binary.fromByteArray(bytes) + + case bytes if bytes.length <= reusableDecimalBytes.length => + val signedByte = (if (bytes.head < 0) -1 else 0).toByte + java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte) + System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length) + Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) + + case bytes => + reusableDecimalBytes = new Array[Byte](bytes.length) + bigIntegerToBinary(unscaled) + } + } + + val binary = if (numBytes <= 8) { + longToBinary(decimal.toUnscaledLong) + } else { + bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue()) + } + + writer.addBinary(binary) } // array used to write Timestamp as Int96 (fixed-length binary) @@ -268,11 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) case BinaryType => writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case d: DecimalType => - if (d.precision > 18) { - sys.error(s"Unsupported datatype $d, cannot write to consumer") - } - writeDecimal(record.getDecimal(index), d.precision) + case DecimalType.Fixed(precision, _) => + writeDecimal(record.getDecimal(index), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index b5314a3dd92e5..b415da5b8c136 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -106,21 +106,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest { // Parquet doesn't allow column names with spaces, have to add an alias here .select($"_1" cast decimal as "dec") - for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) { withTempPath { dir => val data = makeDecimalRDD(DecimalType(precision, scale)) data.write.parquet(dir.getCanonicalPath) checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq) } } - - // Decimals with precision above 18 are not yet supported - intercept[Throwable] { - withTempPath { dir => - makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).collect() - } - } } test("date type") { From 90006f3c51f8cf9535854246050e27bb76b043f0 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Tue, 28 Jul 2015 01:33:31 +0900 Subject: [PATCH 081/219] Pregel example type fix Pregel example to express single source shortest path from https://spark.apache.org/docs/latest/graphx-programming-guide.html#pregel-api does not work due to incorrect type. 
The reason is that `GraphGenerators.logNormalGraph` returns a graph with `Long` vertices. Fixing `val graph: Graph[Int, Double]` to `val graph: Graph[Long, Double]`. Author: Alexander Ulanov Closes #7695 from avulanov/SPARK-9380-pregel-doc and squashes the following commits: c269429 [Alexander Ulanov] Pregel example type fix --- docs/graphx-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 3f10cb2dc3d2a..99f8c827f767f 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -800,7 +800,7 @@ import org.apache.spark.graphx._ // Import random graph generation library import org.apache.spark.graphx.util.GraphGenerators // A graph with edge attributes containing distances -val graph: Graph[Int, Double] = +val graph: Graph[Long, Double] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble) val sourceId: VertexId = 42 // The ultimate source // Initialize the graph such that all vertices except the root have distance infinity. From ecad9d4346ec158746e61aebdf1590215a77f369 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 27 Jul 2015 09:34:49 -0700 Subject: [PATCH 082/219] [SPARK-9364] Fix array out of bounds and use-after-free bugs in UnsafeExternalSorter This patch fixes two bugs in UnsafeExternalSorter and UnsafeExternalRowSorter:
- UnsafeExternalSorter does not properly update freeSpaceInCurrentPage, which can cause it to write past the end of memory pages and trigger segfaults.
- UnsafeExternalRowSorter has a use-after-free bug when returning the last row from an iterator.
Author: Josh Rosen Closes #7680 from JoshRosen/SPARK-9364 and squashes the following commits: 590f311 [Josh Rosen] null out row f4cf91d [Josh Rosen] Fix use-after-free bug in UnsafeExternalRowSorter.
8abcf82 [Josh Rosen] Properly decrement freeSpaceInCurrentPage in UnsafeExternalSorter --- .../unsafe/sort/UnsafeExternalSorter.java | 7 ++++++- .../sort/UnsafeExternalSorterSuite.java | 19 +++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 9 ++++++--- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4d6731ee60af3..80b03d7e99e2b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -150,6 +150,11 @@ private long getMemoryUsage() { return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); } + @VisibleForTesting + public int getNumberOfAllocatedPages() { + return allocatedPages.size(); + } + public long freeMemory() { long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { @@ -257,7 +262,7 @@ public void insertRecord( currentPagePosition, lengthInBytes); currentPagePosition += lengthInBytes; - + freeSpaceInCurrentPage -= totalSpaceRequired; sorter.insertRecord(recordAddress, prefix); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index ea8755e21eb68..0e391b751226d 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -199,4 +199,23 @@ public void testSortingEmptyArrays() throws Exception { } } + @Test + public void testFillingPage() throws Exception { + final UnsafeExternalSorter sorter = new UnsafeExternalSorter( + memoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + 1024, + new SparkConf()); + + byte[] record = new byte[16]; + while (sorter.getNumberOfAllocatedPages() < 2) { + sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); + } + sorter.freeMemory(); + } + } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index be4ff400c4754..4c3f2c6557140 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -124,7 +124,7 @@ Iterator sort() throws IOException { return new AbstractScalaRowIterator() { private final int numFields = schema.length(); - private final UnsafeRow row = new UnsafeRow(); + private UnsafeRow row = new UnsafeRow(); @Override public boolean hasNext() { @@ -141,10 +141,13 @@ public InternalRow next() { numFields, sortedIterator.getRecordLength()); if (!hasNext()) { - row.copy(); // so that we don't have dangling pointers to freed page + UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page + row = null; // so that we don't keep references to the base object cleanupResources(); + return copy; + } else { + return row; } - return row; } catch (IOException e) { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack From 
c0b7df68f81c2a2a9c1065009fe75c278fa30499 Mon Sep 17 00:00:00 2001 From: Ryan Williams Date: Mon, 27 Jul 2015 12:54:08 -0500 Subject: [PATCH 083/219] [SPARK-9366] use task's stageAttemptId in TaskEnd event Author: Ryan Williams Closes #7681 from ryan-williams/task-stage-attempt and squashes the following commits: d6d5f0f [Ryan Williams] use task's stageAttemptId in TaskEnd event --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 552dabcfa5139..b6a833bbb0833 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -927,7 +927,7 @@ class DAGScheduler( // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. if (event.reason != Success) { - val attemptId = stageIdToStage.get(task.stageId).map(_.latestInfo.attemptId).getOrElse(-1) + val attemptId = task.stageAttemptId listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) } From e2f38167f8b5678ac45794eacb9c7bb9b951af82 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 27 Jul 2015 11:02:16 -0700 Subject: [PATCH 084/219] [SPARK-9376] [SQL] use a seed in RandomDataGeneratorSuite Make this test deterministic, i.e. make sure it passes no matter how many times we run it. The original implementation used a random seed, which left a chance of breaking the null check assertion `assert(Iterator.fill(100)(generator()).contains(null))`. Author: Wenchen Fan Closes #7691 from cloud-fan/seed and squashes the following commits: eae7281 [Wenchen Fan] use a seed in RandomDataGeneratorSuite --- .../scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala index 677ba0a18040c..cccac7efa09e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala @@ -32,7 +32,7 @@ class RandomDataGeneratorSuite extends SparkFunSuite { */ def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) - val generator = RandomDataGenerator.forType(dataType, nullable).getOrElse { + val generator = RandomDataGenerator.forType(dataType, nullable, Some(33)).getOrElse { fail(s"Random data generator was not defined for $dataType") } if (nullable) { From 1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42 Mon Sep 17 00:00:00 2001 From: George Dittmar Date: Mon, 27 Jul 2015 11:16:33 -0700 Subject: [PATCH 085/219] [SPARK-7423] [MLLIB] Modify ClassificationModel and Probabilistic model to use Vector.argmax Use the Vector.argmax call instead of converting to a dense vector before calculating predictions.
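Conceptually, the change just drops an unnecessary dense copy; a minimal sketch (values made up):

    import org.apache.spark.mllib.linalg.Vectors

    val rawPrediction = Vectors.sparse(1000, Array(3), Array(5.0))
    // Before: materialize a DenseVector, then scan it for the max index.
    val before = rawPrediction.toDense.argmax
    // After: call argmax directly on the (possibly sparse) Vector.
    val after = rawPrediction.argmax
    assert(before == after) // both yield 3, without the dense copy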
Author: George Dittmar Closes #7670 from GeorgeDittmar/sprk-7423 and squashes the following commits: e796747 [George Dittmar] Changing ClassificationModel and ProbabilisticClassificationModel to use Vector.argmax instead of converting to DenseVector --- .../scala/org/apache/spark/ml/classification/Classifier.scala | 2 +- .../spark/ml/classification/ProbabilisticClassifier.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 85c097bc64a4f..581d8fa7749be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -156,5 +156,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 38e832372698c..dad451108626d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -173,5 +173,5 @@ private[spark] abstract class ProbabilisticClassificationModel[ * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax + protected def probability2prediction(probability: Vector): Double = probability.argmax } From dd9ae7945ab65d353ed2b113e0c1a00a0533ffd6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 27 Jul 2015 11:23:29 -0700 Subject: [PATCH 086/219] [SPARK-9351] [SQL] remove literals from grouping expressions in Aggregate Literals in grouping expressions have no effect at all; they only make the grouping key bigger, so we should remove them in the Optimizer. I also make the old and new aggregation code paths consistent about literals in grouping here: in the old aggregation path, literals in grouping are already removed, but in the new one they are not. So I explicitly make the removal a rule in the Optimizer.
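The effect mirrors the new test added in AggregateOptimizeSuite below; as a sketch using the catalyst DSL (the same construction as in that test):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val input = LocalRelation('a.int, 'b.int)
    // Logically: SELECT sum(b) FROM input GROUP BY a, 1, 1 + 2
    val query = input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    // After the rule runs, only 'a remains in the grouping expressions.
    val optimized = RemoveLiteralFromGroupExpressions(query)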
Author: Wenchen Fan Closes #7583 from cloud-fan/minor and squashes the following commits: 471adff [Wenchen Fan] add test 0839925 [Wenchen Fan] use transformDown when rewrite final result expressions --- .../sql/catalyst/optimizer/Optimizer.scala | 17 +++++++++-- .../sql/catalyst/planning/patterns.scala | 4 +-- ...ite.scala => AggregateOptimizeSuite.scala} | 19 ++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 29 +++++++++++++++---- 4 files changed, 57 insertions(+), 12 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ReplaceDistinctWithAggregateSuite.scala => AggregateOptimizeSuite.scala} (72%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b59f800e7cc0f..813c62009666c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -36,8 +36,9 @@ object DefaultOptimizer extends Optimizer { // SubQueries are only needed for analysis and can be removed before execution. Batch("Remove SubQueries", FixedPoint(100), EliminateSubQueries) :: - Batch("Distinct", FixedPoint(100), - ReplaceDistinctWithAggregate) :: + Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down SetOperationPushDown, @@ -799,3 +800,15 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { case Distinct(child) => Aggregate(child.output, child.output, child) } } + +/** + * Removes literals from group expressions in [[Aggregate]], as they have no effect on the result + * but only make the grouping key bigger. + */ +object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(grouping, _, _) => + val newGrouping = grouping.filter(!_.foldable) + a.copy(groupingExpressions = newGrouping) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1e7b2a536ac12..b9ca712c1ee1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -144,14 +144,14 @@ object PartialAggregation { // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + groupingExpressions.map { case n: NamedExpression => (n, n) case other => (other, Alias(other, "PartialGroup")()) } // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values.
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala similarity index 72% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index df29a62ff0e15..2d080b95b1292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceDistinctWithAggregateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -19,14 +19,17 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -class ReplaceDistinctWithAggregateSuite extends PlanTest { +class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("ProjectCollapsing", Once, ReplaceDistinctWithAggregate) :: Nil + val batches = Batch("Aggregate", FixedPoint(100), + ReplaceDistinctWithAggregate, + RemoveLiteralFromGroupExpressions) :: Nil } test("replace distinct with aggregate") { @@ -39,4 +42,16 @@ class ReplaceDistinctWithAggregateSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("remove literals in grouping expression") { + val input = LocalRelation('a.int, 'b.int) + + val query = + input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b)) + val optimized = Optimize.execute(query) + + val correctAnswer = input.groupBy('a)(sum('b)) + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8cef0b39f87dc..358e319476e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -463,12 +463,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("literal in agg grouping expressions") { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + def literalInAggTest(): Unit = { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + 
checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) + } + + literalInAggTest() + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + literalInAggTest() + } } test("aggregates with nulls") { From 75438422c2cd90dca53f84879cddecfc2ee0e957 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 27 Jul 2015 11:28:22 -0700 Subject: [PATCH 087/219] [SPARK-9369][SQL] Support IntervalType in UnsafeRow Author: Wenchen Fan Closes #7688 from cloud-fan/interval and squashes the following commits: 5b36b17 [Wenchen Fan] fix codegen a99ed50 [Wenchen Fan] address comment 9e6d319 [Wenchen Fan] Support IntervalType in UnsafeRow --- .../sql/catalyst/expressions/UnsafeRow.java | 23 ++++++++++++++----- .../expressions/UnsafeRowWriters.java | 19 ++++++++++++++- .../spark/sql/catalyst/InternalRow.scala | 4 +++- .../catalyst/expressions/BoundAttribute.scala | 1 + .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 7 +++--- .../codegen/GenerateUnsafeProjection.scala | 6 +++++ .../expressions/ExpressionEvalHelper.scala | 2 -- 8 files changed, 50 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 0fb33dd5a15a0..fb084dd13b620 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -29,6 +29,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; +import org.apache.spark.unsafe.types.Interval; import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; @@ -90,7 +91,8 @@ public static int calculateBitSetWidthInBytes(int numFields) { final Set _readableFieldTypes = new HashSet<>( Arrays.asList(new DataType[]{ StringType, - BinaryType + BinaryType, + IntervalType })); _readableFieldTypes.addAll(settableFieldTypes); readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); @@ -332,11 +334,6 @@ public UTF8String getUTF8String(int ordinal) { return isNullAt(ordinal) ? 
null : UTF8String.fromBytes(getBinary(ordinal)); } - @Override - public String getString(int ordinal) { - return getUTF8String(ordinal).toString(); - } - @Override public byte[] getBinary(int ordinal) { if (isNullAt(ordinal)) { @@ -358,6 +355,20 @@ public byte[] getBinary(int ordinal) { } } + @Override + public Interval getInterval(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long microseconds = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + return new Interval(months, microseconds); + } + } + @Override public UnsafeRow getStruct(int ordinal, int numFields) { if (isNullAt(ordinal)) { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 87521d1f23c99..0ba31d3b9b743 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -20,6 +20,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; +import org.apache.spark.unsafe.types.Interval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -54,7 +55,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in } } - /** Writer for bianry (byte array) type. */ + /** Writer for binary (byte array) type. */ public static class BinaryWriter { public static int getSize(byte[] input) { @@ -80,4 +81,20 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) } } + /** Writer for interval type. */ + public static class IntervalWriter { + + public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) { + final long offset = target.getBaseOffset() + cursor; + + // Write the months and microseconds fields of Interval to the variable length portion. + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, input.months); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, input.microseconds); + + // Set the fixed length portion. 
+ target.setLong(ordinal, ((long) cursor) << 32); + return 16; + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index ad3977281d1a9..9a11de3840ce2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{Interval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -60,6 +60,8 @@ abstract class InternalRow extends Serializable { def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) + // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 6b5c450e3fb0a..41a877f214e55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -48,6 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case DoubleType => input.getDouble(ordinal) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) + case IntervalType => input.getInterval(ordinal) case t: StructType => input.getStruct(ordinal, t.size) case dataType => input.get(ordinal, dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index e208262da96dc..bd8b0177eb00e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = org.apache.spark.unsafe.types.Interval.fromString($c.toString());" + s"$evPrim = Interval.fromString($c.toString());" } private[this] def decimalToTimestampCode(d: String): String = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2a1e288cb8377..2f02c90b1d5b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -79,7 +79,6 @@ class CodeGenContext { mutableStates += ((javaType, variableName, initCode)) } - final val intervalType: String = classOf[Interval].getName final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = 
"short" @@ -109,6 +108,7 @@ class CodeGenContext { case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" + case IntervalType => s"$row.getInterval($ordinal)" case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case _ => s"($jt)$row.get($ordinal)" } @@ -150,7 +150,7 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" - case IntervalType => intervalType + case IntervalType => "Interval" case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" @@ -292,7 +292,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[InternalRow].getName, classOf[UnsafeRow].getName, classOf[UTF8String].getName, - classOf[Decimal].getName + classOf[Decimal].getName, + classOf[Interval].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index afd0d9cfa1ddd..9d2161947b351 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -33,10 +33,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName + private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case t: AtomicType if !t.isInstanceOf[DecimalType] => true + case _: IntervalType => true case NullType => true case _ => false } @@ -68,6 +70,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" case BinaryType => s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" + case IntervalType => + s" + (${exprs(i).isNull} ? 
0 : 16)" case _ => "" } }.mkString("") @@ -80,6 +84,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case BinaryType => s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" + case IntervalType => + s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 8b0f90cf3a623..ab0cdc857c80e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -78,8 +78,6 @@ trait ExpressionEvalHelper { generator } catch { case e: Throwable => - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) fail( s""" |Code generation of $expression failed: From 85a50a6352b72c4619d010e29e3a76774dbc0c71 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 12:25:34 -0700 Subject: [PATCH 088/219] [HOTFIX] Disable pylint since it is failing master. --- dev/lint-python | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index e02dff220eb87..53bccc1fab535 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -96,19 +96,19 @@ fi rm "$PEP8_REPORT_PATH" -for to_be_checked in "$PATHS_TO_CHECK" -do - pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" -done - -if [ "${PIPESTATUS[0]}" -ne 0 ]; then - lint_status=1 - echo "Pylint checks failed." - cat "$PYLINT_REPORT_PATH" -else - echo "Pylint checks passed." -fi - -rm "$PYLINT_REPORT_PATH" +# for to_be_checked in "$PATHS_TO_CHECK" +# do +# pylint --rcfile="$SPARK_ROOT_DIR/pylintrc" $to_be_checked >> "$PYLINT_REPORT_PATH" +# done + +# if [ "${PIPESTATUS[0]}" -ne 0 ]; then +# lint_status=1 +# echo "Pylint checks failed." +# cat "$PYLINT_REPORT_PATH" +# else +# echo "Pylint checks passed." +# fi + +# rm "$PYLINT_REPORT_PATH" exit "$lint_status" From fa84e4a7ba6eab476487185178a556e4f04e4199 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 13:21:04 -0700 Subject: [PATCH 089/219] Closes #7690 since it has been merged into branch-1.4. From 55946e76fd136958081f073c0c5e3ff8563d505b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 27 Jul 2015 13:26:57 -0700 Subject: [PATCH 090/219] [SPARK-9349] [SQL] UDAF cleanup https://issues.apache.org/jira/browse/SPARK-9349 With this PR, we only expose `UserDefinedAggregateFunction` (an abstract class) and `MutableAggregationBuffer` (an interface). Other internal wrappers and helper classes are moved to `org.apache.spark.sql.execution.aggregate` and marked as `private[sql]`. Author: Yin Huai Closes #7687 from yhuai/UDAF-cleanup and squashes the following commits: db36542 [Yin Huai] Add comments to UDAF examples. ae17f66 [Yin Huai] Address comments. 9c9fa5f [Yin Huai] UDAF cleanup. 
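For reference, a minimal sketch of what a UDAF looks like against the now-public API (the `SimpleDoubleSum` class below is illustrative and not part of this patch; its overridden members are exactly the abstract methods of `UserDefinedAggregateFunction`):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

class SimpleDoubleSum extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("input", DoubleType)
  def bufferSchema: StructType = new StructType().add("sum", DoubleType)
  def returnDataType: DataType = DoubleType
  def deterministic: Boolean = true
  // 0.0 is the identity for merge, so merging freshly initialized buffers
  // still yields an initial buffer, as the initialize() contract requires.
  def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0.0)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer.update(0, buffer.getDouble(0) + input.getDouble(0))
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0))
  }
  def evaluate(buffer: Row): Any = buffer.getDouble(0)
}
```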
--- .../apache/spark/sql/UDAFRegistration.scala | 3 +- .../aggregate/udaf.scala | 122 +++++------------- .../apache/spark/sql/expressions/udaf.scala | 101 +++++++++++++++ .../spark/sql/hive/aggregate/MyDoubleAvg.java | 34 ++++- .../spark/sql/hive/aggregate/MyDoubleSum.java | 28 +++- 5 files changed, 187 insertions(+), 101 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/{expressions => execution}/aggregate/udaf.scala (67%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala index 5b872f5e3eecd..0d4e30f29255e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDAFRegistration.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Expression} -import org.apache.spark.sql.expressions.aggregate.{ScalaUDAF, UserDefinedAggregateFunction} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction class UDAFRegistration private[sql] (sqlContext: SQLContext) extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala similarity index 67% rename from sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 4ada9eca7a035..073c45ae2f9f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -15,87 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.expressions.aggregate +package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 -import org.apache.spark.sql.types._ -import org.apache.spark.sql.Row +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.types.{Metadata, StructField, StructType, DataType} /** - * The abstract class for implementing user-defined aggregate function. + * A mutable [[Row]] representing a mutable aggregation buffer. */ -abstract class UserDefinedAggregateFunction extends Serializable { - - /** - * A [[StructType]] represents data types of input arguments of this aggregate function. 
- * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments - * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * input argument. Users can choose names to identify the input arguments. - */ - def inputSchema: StructType - - /** - * A [[StructType]] represents data types of values in the aggregation buffer. - * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values - * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], - * the returned [[StructType]] will look like - * - * ``` - * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) - * ``` - * - * The name of a field of this [[StructType]] is only used to identify the corresponding - * buffer value. Users can choose names to identify the input arguments. - */ - def bufferSchema: StructType - - /** - * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. - */ - def returnDataType: DataType - - /** Indicates if this function is deterministic. */ - def deterministic: Boolean - - /** - * Initializes the given aggregation buffer. Initial values set by this method should satisfy - * the condition that when merging two buffers with initial values, the new buffer should - * still store initial values. - */ - def initialize(buffer: MutableAggregationBuffer): Unit - - /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ - def update(buffer: MutableAggregationBuffer, input: Row): Unit - - /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ - def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit - - /** - * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given - * aggregation buffer. - */ - def evaluate(buffer: Row): Any -} - -private[sql] abstract class AggregationBuffer( +private[sql] class MutableAggregationBufferImpl ( + schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], - bufferOffset: Int) - extends Row { - - override def length: Int = toCatalystConverters.length + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends MutableAggregationBuffer { - protected val offsets: Array[Int] = { + private[this] val offsets: Array[Int] = { val newOffsets = new Array[Int](length) var i = 0 while (i < newOffsets.length) { @@ -104,18 +46,8 @@ private[sql] abstract class AggregationBuffer( } newOffsets } -} -/** - * A Mutable [[Row]] representing an mutable aggregation buffer. 
- */ -class MutableAggregationBuffer private[sql] ( - schema: StructType, - toCatalystConverters: Array[Any => Any], - toScalaConverters: Array[Any => Any], - bufferOffset: Int, - var underlyingBuffer: MutableRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { if (i >= length || i < 0) { @@ -133,8 +65,8 @@ class MutableAggregationBuffer private[sql] ( underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) } - override def copy(): MutableAggregationBuffer = { - new MutableAggregationBuffer( + override def copy(): MutableAggregationBufferImpl = { + new MutableAggregationBufferImpl( schema, toCatalystConverters, toScalaConverters, @@ -146,13 +78,25 @@ class MutableAggregationBuffer private[sql] ( /** * A [[Row]] representing an immutable aggregation buffer. */ -class InputAggregationBuffer private[sql] ( +private[sql] class InputAggregationBuffer private[sql] ( schema: StructType, toCatalystConverters: Array[Any => Any], toScalaConverters: Array[Any => Any], bufferOffset: Int, var underlyingInputBuffer: InternalRow) - extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + extends Row { + + private[this] val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } + + override def length: Int = toCatalystConverters.length override def get(i: Int): Any = { if (i >= length || i < 0) { @@ -179,7 +123,7 @@ class InputAggregationBuffer private[sql] ( * @param children * @param udaf */ -case class ScalaUDAF( +private[sql] case class ScalaUDAF( children: Seq[Expression], udaf: UserDefinedAggregateFunction) extends AggregateFunction2 with Logging { @@ -243,8 +187,8 @@ case class ScalaUDAF( bufferOffset, null) - lazy val mutableAggregateBuffer: MutableAggregationBuffer = - new MutableAggregationBuffer( + lazy val mutableAggregateBuffer: MutableAggregationBufferImpl = + new MutableAggregationBufferImpl( bufferSchema, bufferValuesToCatalystConverters, bufferValuesToScalaConverters, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala new file mode 100644 index 0000000000000..278dd438fab4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * The abstract class for implementing user-defined aggregate functions. + */ +@Experimental +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * new StructType() + * .add("doubleInput", DoubleType) + * .add("longInput", LongType) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the buffer values. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. */ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer should + * still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back to `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +/** + * :: Experimental :: + * A [[Row]] representing a mutable aggregation buffer. + */ +@Experimental +trait MutableAggregationBuffer extends Row { + + /** Update the ith value of this buffer. 
*/ + def update(i: Int, value: Any): Unit +} diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java index 5c9d0e97a99c6..a2247e3da1554 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -21,13 +21,18 @@ import java.util.List; import org.apache.spark.sql.Row; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +/** + * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a + * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the + * average of the input values plus 100.0. + */ public class MyDoubleAvg extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,10 +42,13 @@ public class MyDoubleAvg extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleAvg() { - List inputfields = new ArrayList(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); + // The buffer has two values, bufferSum for storing the current sum and + // bufferCount for storing the number of non-null input values that have contributed + // to the current sum. List bufferFields = new ArrayList(); bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); @@ -66,16 +74,23 @@ public MyDoubleAvg() { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); + // The initial value of the count is 0. buffer.update(1, 0L); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we store the input value in the buffer and set the bufferCount to 1. if (buffer.isNullAt(0)) { buffer.update(0, input.getDouble(0)); buffer.update(1, 1L); } else { + // Otherwise, update the bufferSum and increment bufferCount. Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); buffer.update(1, buffer.getLong(1) + 1L); @@ -84,11 +99,16 @@ public MyDoubleAvg() { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update buffer1 when the input buffer2's sum value is not null. 
if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set it to the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); buffer1.update(1, buffer2.getLong(1)); } else { + // Otherwise, we update the bufferSum and bufferCount. Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); @@ -98,10 +118,12 @@ public MyDoubleAvg() { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the bufferSum is still null, we return null because this function has not seen + // any input rows. return null; } else { + // Otherwise, we calculate the special average value. return buffer.getDouble(0) / buffer.getLong(1) + 100.0; } } } - diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java index 1d4587a27c787..da29e24d267dd 100644 --- a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -20,14 +20,18 @@ import java.util.ArrayList; import java.util.List; -import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; -import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.expressions.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.Row; +/** + * An example {@link UserDefinedAggregateFunction} to calculate the sum of a + * {@link org.apache.spark.sql.types.DoubleType} column. + */ public class MyDoubleSum extends UserDefinedAggregateFunction { private StructType _inputDataType; @@ -37,9 +41,9 @@ public class MyDoubleSum extends UserDefinedAggregateFunction { private DataType _returnDataType; public MyDoubleSum() { - List inputfields = new ArrayList(); - inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); - _inputDataType = DataTypes.createStructType(inputfields); + List inputFields = new ArrayList(); + inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputFields); List bufferFields = new ArrayList(); bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); @@ -65,14 +69,20 @@ public MyDoubleSum() { } @Override public void initialize(MutableAggregationBuffer buffer) { + // The initial value of the sum is null. buffer.update(0, null); } @Override public void update(MutableAggregationBuffer buffer, Row input) { + // This input Row only has a single column storing the input value in Double. + // We only update the buffer when the input value is not null. if (!input.isNullAt(0)) { if (buffer.isNullAt(0)) { + // If the buffer value (the intermediate result of the sum) is still null, + // we store the input value in the buffer. buffer.update(0, input.getDouble(0)); } else { + // Otherwise, we add the input value to the buffer value. 
Double newValue = input.getDouble(0) + buffer.getDouble(0); buffer.update(0, newValue); } @@ -80,10 +90,16 @@ public MyDoubleSum() { } @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + // buffer1 and buffer2 have the same structure. + // We only update buffer1 when the input buffer2's value is not null. if (!buffer2.isNullAt(0)) { if (buffer1.isNullAt(0)) { + // If the buffer value (intermediate result of the sum) is still null, + // we set it to the input buffer's value. buffer1.update(0, buffer2.getDouble(0)); } else { + // Otherwise, we add the input buffer's value (buffer2) to the mutable + // buffer's value (buffer1). Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); buffer1.update(0, newValue); } @@ -92,8 +108,10 @@ public MyDoubleSum() { @Override public Object evaluate(Row buffer) { if (buffer.isNullAt(0)) { + // If the buffer value is still null, we return null. return null; } else { + // Otherwise, the intermediate sum is the final result. return buffer.getDouble(0); } } From 8e7d2bee23dad1535846dae2dc31e35058db16cd Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 27 Jul 2015 13:28:03 -0700 Subject: [PATCH 091/219] [SPARK-9378] [SQL] Fixes test case "CTAS with serde" This is a proper version of PR #7693 authored by viirya The reason why "CTAS with serde" fails is that the `MetastoreRelation` gets converted to a Parquet data source relation by default. Author: Cheng Lian Closes #7700 from liancheng/spark-9378-fix-ctas-test and squashes the following commits: 4413af0 [Cheng Lian] Fixes test case "CTAS with serde" --- .../spark/sql/hive/execution/SQLQuerySuite.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 8371dd0716c06..c4923d83e48f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -406,13 +406,15 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils { | FROM src | ORDER BY key, value""".stripMargin).collect() - checkExistence(sql("DESC EXTENDED ctas5"), true, - "name:key", "type:string", "name:value", "ctas5", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", - "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", - "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", - "MANAGED_TABLE" - ) + withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { + checkExistence(sql("DESC EXTENDED ctas5"), true, + "name:key", "type:string", "name:value", "ctas5", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "MANAGED_TABLE" + ) + } // use the Hive SerDe for parquet tables withSQLConf(HiveContext.CONVERT_METASTORE_PARQUET.key -> "false") { From 3ab7525dceeb1c2f3c21efb1ee5a9c8bb0fd0c13 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 27 Jul 2015 13:40:50 -0700 Subject: [PATCH 092/219] [SPARK-9355][SQL] Remove InternalRow.get generic getter call in columnar cache code Author: Wenchen Fan Closes #7673 from cloud-fan/row-generic-getter-columnar and squashes the following commits: 88b1170 [Wenchen Fan] fix style eeae712 [Wenchen Fan] Remove Internal.get generic getter call in columnar cache code --- 
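The shape of the change, as a hedged sketch (the `readColumn` helper is hypothetical, for illustration only): every read in the columnar cache code now supplies the column's catalyst `DataType` to the row accessor instead of going through the untyped `get(ordinal)`, so specialized rows can avoid the generic, boxing code path.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.DataType

// Before this patch the columnar code called row.get(ordinal); afterwards every
// read passes the column's DataType, e.g. columnType.dataType.
def readColumn(row: InternalRow, ordinal: Int, dataType: DataType): Any =
  row.get(ordinal, dataType)
```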
.../spark/sql/columnar/ColumnAccessor.scala | 12 ++--- .../spark/sql/columnar/ColumnBuilder.scala | 18 +++---- .../spark/sql/columnar/ColumnStats.scala | 6 ++- .../spark/sql/columnar/ColumnType.scala | 49 +++++++++++-------- .../compression/CompressionScheme.scala | 2 +- .../compression/compressionSchemes.scala | 14 +++--- .../spark/sql/columnar/ColumnStatsSuite.scala | 12 ++--- .../spark/sql/columnar/ColumnTypeSuite.scala | 30 ++++++------ .../sql/columnar/ColumnarTestUtils.scala | 18 +++---- .../NullableColumnAccessorSuite.scala | 18 +++---- .../columnar/NullableColumnBuilderSuite.scala | 21 ++++---- .../compression/BooleanBitSetSuite.scala | 2 +- 12 files changed, 107 insertions(+), 95 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 931469bed634a..4c29a093218a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -41,9 +41,9 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( +private[sql] abstract class BasicColumnAccessor[JvmType]( protected val buffer: ByteBuffer, - protected val columnType: ColumnType[T, JvmType]) + protected val columnType: ColumnType[JvmType]) extends ColumnAccessor { protected def initialize() {} @@ -93,14 +93,14 @@ private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) + extends BasicColumnAccessor[Array[Byte]](buffer, BINARY) with NullableColumnAccessor private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) -private[sql] class GenericColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) +private[sql] class GenericColumnAccessor(buffer: ByteBuffer, dataType: DataType) + extends BasicColumnAccessor[Array[Byte]](buffer, GENERIC(dataType)) with NullableColumnAccessor private[sql] class DateColumnAccessor(buffer: ByteBuffer) @@ -131,7 +131,7 @@ private[sql] object ColumnAccessor { case BinaryType => new BinaryColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) - case _ => new GenericColumnAccessor(dup) + case other => new GenericColumnAccessor(dup, other) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 087c52239713d..454b7b91a63f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -46,9 +46,9 @@ private[sql] trait ColumnBuilder { def build(): ByteBuffer } -private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( +private[sql] class BasicColumnBuilder[JvmType]( val columnStats: ColumnStats, - val columnType: ColumnType[T, JvmType]) + val columnType: ColumnType[JvmType]) extends ColumnBuilder { protected var columnName: String = _ @@ -78,16 +78,16 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } } -private[sql] abstract class 
ComplexColumnBuilder[T <: DataType, JvmType]( +private[sql] abstract class ComplexColumnBuilder[JvmType]( columnStats: ColumnStats, - columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](columnStats, columnType) + columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](columnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#InternalType](columnStats, columnType) + extends BasicColumnBuilder[T#InternalType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] @@ -118,8 +118,8 @@ private[sql] class FixedDecimalColumnBuilder( FIXED_DECIMAL(precision, scale)) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) +private[sql] class GenericColumnBuilder(dataType: DataType) + extends ComplexColumnBuilder(new GenericColumnStats(dataType), GENERIC(dataType)) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) @@ -164,7 +164,7 @@ private[sql] object ColumnBuilder { case BinaryType => new BinaryColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) - case _ => new GenericColumnBuilder + case other => new GenericColumnBuilder(other) } builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 7c63179af6470..32a84b2676e07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -252,11 +252,13 @@ private[sql] class FixedDecimalColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class GenericColumnStats extends ColumnStats { +private[sql] class GenericColumnStats(dataType: DataType) extends ColumnStats { + val columnType = GENERIC(dataType) + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += GENERIC.actualSize(row, ordinal) + sizeInBytes += columnType.actualSize(row, ordinal) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index c0ca52751b66c..2863f6c230a9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -31,14 +31,18 @@ import org.apache.spark.unsafe.types.UTF8String * An abstract class that represents type of a column. Used to append/extract Java objects into/from * the underlying [[ByteBuffer]] of a column. * - * @param typeId A unique ID representing the type. - * @param defaultSize Default size in bytes for one element of type T (e.g. 4 for `Int`). - * @tparam T Scala data type for the column. * @tparam JvmType Underlying Java type to represent the elements. 
*/ -private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( - val typeId: Int, - val defaultSize: Int) { +private[sql] sealed abstract class ColumnType[JvmType] { + + // The catalyst data type of this column. + def dataType: DataType + + // A unique ID representing the type. + def typeId: Int + + // Default size in bytes for one element of this type (e.g. 4 for `Int`). + def defaultSize: Int /** * Extracts a value out of the buffer at the buffer's current position. @@ -90,7 +94,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * boxing/unboxing costs whenever possible. */ def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { - to(toOrdinal) = from.get(fromOrdinal) + to.update(toOrdinal, from.get(fromOrdinal, dataType)) } /** @@ -103,9 +107,9 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( private[sql] abstract class NativeColumnType[T <: AtomicType]( val dataType: T, - typeId: Int, - defaultSize: Int) - extends ColumnType[T, T#InternalType](typeId, defaultSize) { + val typeId: Int, + val defaultSize: Int) + extends ColumnType[T#InternalType] { /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. @@ -400,10 +404,10 @@ private[sql] object FIXED_DECIMAL { val defaultSize = 8 } -private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( - typeId: Int, - defaultSize: Int) - extends ColumnType[T, Array[Byte]](typeId, defaultSize) { +private[sql] sealed abstract class ByteArrayColumnType( + val typeId: Int, + val defaultSize: Int) + extends ColumnType[Array[Byte]] { override def actualSize(row: InternalRow, ordinal: Int): Int = { getField(row, ordinal).length + 4 @@ -421,9 +425,12 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) { +private[sql] object BINARY extends ByteArrayColumnType(11, 16) { + + def dataType: DataType = BinaryType + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row(ordinal) = value + row.update(ordinal, value) } override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { @@ -434,18 +441,18 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. 
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { +private[sql] case class GENERIC(dataType: DataType) extends ByteArrayColumnType(12, 16) { override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { - row(ordinal) = SparkSqlSerializer.deserialize[Any](value) + row.update(ordinal, SparkSqlSerializer.deserialize[Any](value)) } override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { - SparkSqlSerializer.serialize(row.get(ordinal)) + SparkSqlSerializer.serialize(row.get(ordinal, dataType)) } } private[sql] object ColumnType { - def apply(dataType: DataType): ColumnType[_, _] = { + def apply(dataType: DataType): ColumnType[_] = { dataType match { case BooleanType => BOOLEAN case ByteType => BYTE @@ -460,7 +467,7 @@ private[sql] object ColumnType { case BinaryType => BINARY case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) - case _ => GENERIC + case other => GENERIC(other) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index 4eaec6d853d4d..b1ef9b2ef7849 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -46,7 +46,7 @@ private[sql] trait Decoder[T <: AtomicType] { private[sql] trait CompressionScheme { def typeId: Int - def supports(columnType: ColumnType[_, _]): Boolean + def supports(columnType: ColumnType[_]): Boolean def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 6150df6930b32..c91d960a0932b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 - override def supports(columnType: ColumnType[_, _]): Boolean = true + override def supports(columnType: ColumnType[_]): Boolean = true override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = { new this.Encoder[T](columnType) @@ -78,7 +78,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { new this.Decoder(buffer, columnType) } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { + override def supports(columnType: ColumnType[_]): Boolean = columnType match { case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true case _ => false } @@ -128,7 +128,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme { while (from.hasRemaining) { columnType.extract(from, value, 0) - if (value.get(0) == currentValue.get(0)) { + if (value.get(0, columnType.dataType) == currentValue.get(0, columnType.dataType)) { currentRun += 1 } else { // Writes current run @@ -189,7 +189,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme { new this.Encoder[T](columnType) } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType match { + override def supports(columnType: ColumnType[_]): Boolean = columnType match { case 
INT | LONG | STRING => true case _ => false } @@ -304,7 +304,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme { (new this.Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType == BOOLEAN + override def supports(columnType: ColumnType[_]): Boolean = columnType == BOOLEAN class Encoder extends compression.Encoder[BooleanType.type] { private var _uncompressedSize = 0 @@ -392,7 +392,7 @@ private[sql] case object IntDelta extends CompressionScheme { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType == INT + override def supports(columnType: ColumnType[_]): Boolean = columnType == INT class Encoder extends compression.Encoder[IntegerType.type] { protected var _compressedSize: Int = 0 @@ -472,7 +472,7 @@ private[sql] case object LongDelta extends CompressionScheme { (new Encoder).asInstanceOf[compression.Encoder[T]] } - override def supports(columnType: ColumnType[_, _]): Boolean = columnType == LONG + override def supports(columnType: ColumnType[_]): Boolean = columnType == LONG class Encoder extends compression.Encoder[LongType.type] { protected var _compressedSize: Int = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 31e7b0e72e510..4499a7207031d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -58,15 +58,15 @@ class ColumnStatsSuite extends SparkFunSuite { val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.take(10).map(_.get(0).asInstanceOf[T#InternalType]) + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.get(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.get(1)) - assertResult(10, "Wrong null count")(stats.get(2)) - assertResult(20, "Wrong row count")(stats.get(3)) - assertResult(stats.get(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1)) + assertResult(10, "Wrong null count")(stats.genericGet(2)) + assertResult(20, "Wrong row count")(stats.genericGet(3)) + assertResult(stats.genericGet(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 4d46a657056e0..8f024690efd0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -32,13 +32,15 @@ import org.apache.spark.unsafe.types.UTF8String class ColumnTypeSuite extends SparkFunSuite with Logging { - val DEFAULT_BUFFER_SIZE = 512 + private val DEFAULT_BUFFER_SIZE = 512 + private val MAP_GENERIC = GENERIC(MapType(IntegerType, StringType)) test("defaultSize") { val checks = Map( BOOLEAN -> 1, BYTE -> 1, SHORT 
-> 2, INT -> 4, DATE -> 4, LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, - STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, + MAP_GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -48,8 +50,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } test("actualSize") { - def checkActualSize[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def checkActualSize[JvmType]( + columnType: ColumnType[JvmType], value: JvmType, expected: Int): Unit = { @@ -74,7 +76,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") - checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) + checkActualSize(MAP_GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } testNativeColumnType(BOOLEAN)( @@ -123,7 +125,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { UTF8String.fromBytes(bytes) }) - testColumnType[BinaryType.type, Array[Byte]]( + testColumnType[Array[Byte]]( BINARY, (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) @@ -140,7 +142,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val obj = Map(1 -> "spark", 2 -> "sql") val serializedObj = SparkSqlSerializer.serialize(obj) - GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) + MAP_GENERIC.append(SparkSqlSerializer.serialize(obj), buffer) buffer.rewind() val length = buffer.getInt() @@ -157,7 +159,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(obj, "Deserialized object didn't equal to the original object") { buffer.rewind() - SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + SparkSqlSerializer.deserialize(MAP_GENERIC.extract(buffer)) } } @@ -170,7 +172,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { val obj = CustomClass(Int.MaxValue, Long.MaxValue) val serializedObj = serializer.serialize(obj).array() - GENERIC.append(serializer.serialize(obj).array(), buffer) + MAP_GENERIC.append(serializer.serialize(obj).array(), buffer) buffer.rewind() val length = buffer.getInt @@ -192,7 +194,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { assertResult(obj, "Custom deserialized object didn't equal the original object") { buffer.rewind() - serializer.deserialize(ByteBuffer.wrap(GENERIC.extract(buffer))) + serializer.deserialize(ByteBuffer.wrap(MAP_GENERIC.extract(buffer))) } } @@ -201,11 +203,11 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { - testColumnType[T, T#InternalType](columnType, putter, getter) + testColumnType[T#InternalType](columnType, putter, getter) } - def testColumnType[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def testColumnType[JvmType]( + columnType: ColumnType[JvmType], putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType): Unit = { @@ -262,7 +264,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - assertResult(GENERIC) { + assertResult(GENERIC(DecimalType(19, 0))) { ColumnType(DecimalType(19, 0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index d9861339739c9..79bb7d072feb2 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -31,7 +31,7 @@ object ColumnarTestUtils { row } - def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = { + def makeRandomValue[JvmType](columnType: ColumnType[JvmType]): JvmType = { def randomBytes(length: Int) = { val bytes = new Array[Byte](length) Random.nextBytes(bytes) @@ -58,15 +58,15 @@ object ColumnarTestUtils { } def makeRandomValues( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) - def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = { + def makeRandomValues(columnTypes: Seq[ColumnType[_]]): Seq[Any] = { columnTypes.map(makeRandomValue(_)) } - def makeUniqueRandomValues[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType], + def makeUniqueRandomValues[JvmType]( + columnType: ColumnType[JvmType], count: Int): Seq[JvmType] = { Iterator.iterate(HashSet.empty[JvmType]) { set => @@ -75,10 +75,10 @@ object ColumnarTestUtils { } def makeRandomRow( - head: ColumnType[_ <: DataType, _], - tail: ColumnType[_ <: DataType, _]*): InternalRow = makeRandomRow(Seq(head) ++ tail) + head: ColumnType[_], + tail: ColumnType[_]*): InternalRow = makeRandomRow(Seq(head) ++ tail) - def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): InternalRow = { + def makeRandomRow(columnTypes: Seq[ColumnType[_]]): InternalRow = { val row = new GenericMutableRow(columnTypes.length) makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => row(index) = value diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index d421f4d8d091e..f4f6c7649bfa8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -21,17 +21,17 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{StringType, ArrayType, DataType} -class TestNullableColumnAccessor[T <: DataType, JvmType]( +class TestNullableColumnAccessor[JvmType]( buffer: ByteBuffer, - columnType: ColumnType[T, JvmType]) + columnType: ColumnType[JvmType]) extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor object TestNullableColumnAccessor { - def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) - : TestNullableColumnAccessor[T, JvmType] = { + def apply[JvmType](buffer: ByteBuffer, columnType: ColumnType[JvmType]) + : TestNullableColumnAccessor[JvmType] = { // Skips the column type ID buffer.getInt() new TestNullableColumnAccessor(buffer, columnType) @@ -43,13 +43,13 @@ class NullableColumnAccessorSuite extends SparkFunSuite { Seq( BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) .foreach { testNullableColumnAccessor(_) } - def testNullableColumnAccessor[T <: DataType, JvmType]( - columnType: ColumnType[T, 
JvmType]): Unit = { + def testNullableColumnAccessor[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) @@ -75,7 +75,7 @@ class NullableColumnAccessorSuite extends SparkFunSuite { (0 until 4).foreach { _ => assert(accessor.hasNext) accessor.extractTo(row, 0) - assert(row.get(0) === randomRow.get(0)) + assert(row.get(0, columnType.dataType) === randomRow.get(0, columnType.dataType)) assert(accessor.hasNext) accessor.extractTo(row, 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cd8bf75ff1752..241d09ea205e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ -class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) +class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) + extends BasicColumnBuilder[JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { - def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) - : TestNullableColumnBuilder[T, JvmType] = { + def apply[JvmType](columnType: ColumnType[JvmType], initialSize: Int = 0) + : TestNullableColumnBuilder[JvmType] = { val builder = new TestNullableColumnBuilder(columnType) builder.initialize(initialSize) builder @@ -39,13 +39,13 @@ class NullableColumnBuilderSuite extends SparkFunSuite { Seq( BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, - STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC(ArrayType(StringType))) .foreach { testNullableColumnBuilder(_) } - def testNullableColumnBuilder[T <: DataType, JvmType]( - columnType: ColumnType[T, JvmType]): Unit = { + def testNullableColumnBuilder[JvmType]( + columnType: ColumnType[JvmType]): Unit = { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -92,13 +92,14 @@ class NullableColumnBuilderSuite extends SparkFunSuite { // For non-null values (0 until 4).foreach { _ => - val actual = if (columnType == GENERIC) { - SparkSqlSerializer.deserialize[Any](GENERIC.extract(buffer)) + val actual = if (columnType.isInstanceOf[GENERIC]) { + SparkSqlSerializer.deserialize[Any](columnType.extract(buffer).asInstanceOf[Array[Byte]]) } else { columnType.extract(buffer) } - assert(actual === randomRow.get(0), "Extracted value didn't equal to the original one") + assert(actual === randomRow.get(0, columnType.dataType), + "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 33092c83a1a1c..9a2948c59ba42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -33,7 +33,7 @@ class BooleanBitSetSuite extends SparkFunSuite 
{ val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) - val values = rows.map(_.get(0)) + val values = rows.map(_.getBoolean(0)) rows.foreach(builder.appendFrom(_, 0)) val buffer = builder.build() From c1be9f309acad4d1b1908fa7800e7ef4f3e872ce Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Mon, 27 Jul 2015 15:16:46 -0700 Subject: [PATCH 093/219] =?UTF-8?q?[SPARK-8988]=20[YARN]=20Make=20sure=20d?= =?UTF-8?q?river=20log=20links=20appear=20in=20secure=20cluste=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …r mode. The NodeReports API currently used does not work in secure mode since we do not get RM tokens. Instead this patch just uses environment vars exported by YARN to create the log links. Author: Hari Shreedharan Closes #7624 from harishreedharan/driver-logs-env and squashes the following commits: 7368c7e [Hari Shreedharan] [SPARK-8988][YARN] Make sure driver log links appear in secure cluster mode. --- .../cluster/YarnClusterSchedulerBackend.scala | 71 +++++-------------- 1 file changed, 17 insertions(+), 54 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 33f580aaebdc0..1aed5a1675075 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler.cluster import java.net.NetworkInterface +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + import scala.collection.JavaConverters._ import org.apache.hadoop.yarn.api.records.NodeState @@ -64,68 +66,29 @@ private[spark] class YarnClusterSchedulerBackend( } override def getDriverLogUrls: Option[Map[String, String]] = { - var yarnClientOpt: Option[YarnClient] = None var driverLogs: Option[Map[String, String]] = None try { val yarnConf = new YarnConfiguration(sc.hadoopConfiguration) val containerId = YarnSparkHadoopUtil.get.getContainerId - yarnClientOpt = Some(YarnClient.createYarnClient()) - yarnClientOpt.foreach { yarnClient => - yarnClient.init(yarnConf) - yarnClient.start() - - // For newer versions of YARN, we can find the HTTP address for a given node by getting a - // container report for a given container. But container reports came only in Hadoop 2.4, - // so we basically have to get the node reports for all nodes and find the one which runs - // this container. For that we have to compare the node's host against the current host. - // Since the host can have multiple addresses, we need to compare against all of them to - // find out if one matches. - - // Get all the addresses of this node. - val addresses = - NetworkInterface.getNetworkInterfaces.asScala - .flatMap(_.getInetAddresses.asScala) - .toSeq - - // Find a node report that matches one of the addresses - val nodeReport = - yarnClient.getNodeReports(NodeState.RUNNING).asScala.find { x => - val host = x.getNodeId.getHost - addresses.exists { address => - address.getHostAddress == host || - address.getHostName == host || - address.getCanonicalHostName == host - } - } - // Now that we have found the report for the Node Manager that the AM is running on, we - // can get the base HTTP address for the Node manager from the report. 
- // The format used for the logs for each container is well-known and can be constructed - // using the NM's HTTP address and the container ID. - // The NM may be running several containers, but we can build the URL for the AM using - // the AM's container ID, which we already know. - nodeReport.foreach { report => - val httpAddress = report.getHttpAddress - // lookup appropriate http scheme for container log urls - val yarnHttpPolicy = yarnConf.get( - YarnConfiguration.YARN_HTTP_POLICY_KEY, - YarnConfiguration.YARN_HTTP_POLICY_DEFAULT - ) - val user = Utils.getCurrentUserName() - val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" - val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" - logDebug(s"Base URL for logs: $baseUrl") - driverLogs = Some(Map( - "stderr" -> s"$baseUrl/stderr?start=-4096", - "stdout" -> s"$baseUrl/stdout?start=-4096")) - } - } + val httpAddress = System.getenv(Environment.NM_HOST.name()) + + ":" + System.getenv(Environment.NM_HTTP_PORT.name()) + // lookup appropriate http scheme for container log urls + val yarnHttpPolicy = yarnConf.get( + YarnConfiguration.YARN_HTTP_POLICY_KEY, + YarnConfiguration.YARN_HTTP_POLICY_DEFAULT + ) + val user = Utils.getCurrentUserName() + val httpScheme = if (yarnHttpPolicy == "HTTPS_ONLY") "https://" else "http://" + val baseUrl = s"$httpScheme$httpAddress/node/containerlogs/$containerId/$user" + logDebug(s"Base URL for logs: $baseUrl") + driverLogs = Some(Map( + "stderr" -> s"$baseUrl/stderr?start=-4096", + "stdout" -> s"$baseUrl/stdout?start=-4096")) } catch { case e: Exception => - logInfo("Node Report API is not available in the version of YARN being used, so AM" + + logInfo("Error while building AM log links, so AM" + " logs link will not appear in application UI", e) - } finally { - yarnClientOpt.foreach(_.close()) } driverLogs } From 2104931d7d726eda2c098e0f403c7f1533df8746 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 27 Jul 2015 15:18:48 -0700 Subject: [PATCH 094/219] [SPARK-9385] [HOT-FIX] [PYSPARK] Comment out Python style check https://issues.apache.org/jira/browse/SPARK-9385 Comment out Python style check because of error shown in https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3088/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=centos/console Author: Yin Huai Closes #7702 from yhuai/SPARK-9385 and squashes the following commits: 146e6ef [Yin Huai] Comment out Python style check because of error shown in https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-SBT/3088/AMPLAB_JENKINS_BUILD_PROFILE=hadoop1.0,label=centos/console --- dev/run-tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 1f0d218514f92..d1cb66860b3f8 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,8 +198,9 @@ def run_scala_style_checks(): def run_python_style_checks(): - set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") - run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) + # set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") + # run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) + pass def build_spark_documentation(): From ab625956616664c2b4861781a578311da75a9ae4 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 27 Jul 2015 15:46:35 -0700 Subject: [PATCH 095/219] [SPARK-4352] [YARN] [WIP] Incorporate locality preferences in dynamic allocation requests Currently there's no locality preference for container request in YARN mode, this will affect 
the performance if fetching data remotely, so here proposed to add locality in Yarn dynamic allocation mode. Ping sryza, please help to review, thanks a lot. Author: jerryshao Closes #6394 from jerryshao/SPARK-4352 and squashes the following commits: d45fecb [jerryshao] Add documents 6c3fe5c [jerryshao] Fix bug 8db6c0e [jerryshao] Further address the comments 2e2b2cb [jerryshao] Fix rebase compiling problem ce5f096 [jerryshao] Fix style issue 7f7df95 [jerryshao] Fix rebase issue 9ca9e07 [jerryshao] Code refactor according to comments d3e4236 [jerryshao] Further address the comments 5e7a593 [jerryshao] Fix bug introduced code rebase 9ca7783 [jerryshao] Style changes 08317f9 [jerryshao] code and comment refines 65b2423 [jerryshao] Further address the comments a27c587 [jerryshao] address the comment 27faabc [jerryshao] redundant code remove 9ce06a1 [jerryshao] refactor the code f5ba27b [jerryshao] Style fix 2c6cc8a [jerryshao] Fix bug and add unit tests 0757335 [jerryshao] Consider the distribution of existed containers to recalculate the new container requests 0ad66ff [jerryshao] Fix compile bugs 1c20381 [jerryshao] Minor fix 5ef2dc8 [jerryshao] Add docs and improve the code 3359814 [jerryshao] Fix rebase and test bugs 0398539 [jerryshao] reinitialize the new implementation 67596d6 [jerryshao] Still fix the code 654e1d2 [jerryshao] Fix some bugs 45b1c89 [jerryshao] Further polish the algorithm dea0152 [jerryshao] Enable node locality information in YarnAllocator 74bbcc6 [jerryshao] Support node locality for dynamic allocation initial commit --- .../spark/ExecutorAllocationClient.scala | 18 +- .../spark/ExecutorAllocationManager.scala | 62 +++++- .../scala/org/apache/spark/SparkContext.scala | 25 ++- .../apache/spark/scheduler/DAGScheduler.scala | 26 ++- .../org/apache/spark/scheduler/Stage.scala | 7 +- .../apache/spark/scheduler/StageInfo.scala | 13 +- .../cluster/CoarseGrainedClusterMessage.scala | 6 +- .../CoarseGrainedSchedulerBackend.scala | 32 ++- .../cluster/YarnSchedulerBackend.scala | 3 +- .../ExecutorAllocationManagerSuite.scala | 55 +++++- .../apache/spark/HeartbeatReceiverSuite.scala | 7 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- ...yPreferredContainerPlacementStrategy.scala | 182 ++++++++++++++++++ .../spark/deploy/yarn/YarnAllocator.scala | 47 ++++- .../ContainerPlacementStrategySuite.scala | 125 ++++++++++++ .../deploy/yarn/YarnAllocatorSuite.scala | 14 +- 16 files changed, 578 insertions(+), 49 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala create mode 100644 yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 443830f8d03b6..842bfdbadc948 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -24,11 +24,23 @@ package org.apache.spark private[spark] trait ExecutorAllocationClient { /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. 
The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ - private[spark] def requestTotalExecutors(numExecutors: Int): Boolean + private[spark] def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean /** * Request an additional number of executors from the cluster manager. diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 648bcfe28cad2..1877aaf2cac55 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -161,6 +161,12 @@ private[spark] class ExecutorAllocationManager( // (2) an executor idle timeout has elapsed. @volatile private var initializing: Boolean = true + // Number of locality-aware tasks, used for executor placement. + private var localityAwareTasks = 0 + + // Map from host to the number of tasks that may run on it, used for executor placement. + private var hostToLocalTaskCount: Map[String, Int] = Map.empty + /** * Verify that the settings specified through the config are valid. * If not, throw an appropriate exception. @@ -295,7 +301,7 @@ private[spark] class ExecutorAllocationManager( // If the new target has not changed, avoid sending a message to the cluster manager if (numExecutorsTarget < oldNumExecutorsTarget) { - client.requestTotalExecutors(numExecutorsTarget) + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + s"$oldNumExecutorsTarget) because not all requested executors are actually needed") } @@ -349,7 +355,8 @@ private[spark] class ExecutorAllocationManager( return 0 } - val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget) + val addRequestAcknowledged = testing || + client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) if (addRequestAcknowledged) { val executorsString = "executor" + { if (delta > 1) "s" else "" } logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + @@ -519,6 +526,12 @@ private[spark] class ExecutorAllocationManager( // Number of tasks currently running on the cluster. Should be 0 when no stages are active. private var numRunningTasks: Int = _ + // Map from stageId to a tuple of (the number of tasks with locality preferences, a map where + // each pair is a node and the number of tasks that would like to be scheduled on that node); + // it maintains the executor placement hints for each stage, used by the resource framework + // to better place the executors.
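+ // For example, a stage with three locality-aware tasks that all prefer (host1, host2) is + // recorded here as (3, Map("host1" -> 3, "host2" -> 3)).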
+ private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])] + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { initializing = false val stageId = stageSubmitted.stageInfo.stageId @@ -526,6 +539,24 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks(stageId) = numTasks allocationManager.onSchedulerBacklogged() + + // Compute the number of tasks requested by the stage on each host + var numTasksPending = 0 + val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]() + stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality => + if (!locality.isEmpty) { + numTasksPending += 1 + locality.foreach { location => + val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1 + hostToLocalTaskCountPerStage(location.host) = count + } + } + } + stageIdToExecutorPlacementHints.put(stageId, + (numTasksPending, hostToLocalTaskCountPerStage.toMap)) + + // Update the executor placement hints + updateExecutorPlacementHints() } } @@ -534,6 +565,10 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageIdToNumTasks -= stageId stageIdToTaskIndices -= stageId + stageIdToExecutorPlacementHints -= stageId + + // Update the executor placement hints + updateExecutorPlacementHints() // If this is the last stage with pending tasks, mark the scheduler queue as empty // This is needed in case the stage is aborted for any reason @@ -637,6 +672,29 @@ private[spark] class ExecutorAllocationManager( def isExecutorIdle(executorId: String): Boolean = { !executorIdToTaskIds.contains(executorId) } + + /** + * Update the Executor placement hints (the number of tasks with locality preferences, + * a map where each pair is a node and the number of tasks that would like to be scheduled + * on that node). + * + * These hints are updated when stages arrive and complete, so are not up-to-date at task + * granularity within stages. + */ + def updateExecutorPlacementHints(): Unit = { + var localityAwareTasks = 0 + val localityToCount = new mutable.HashMap[String, Int]() + stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) => + localityAwareTasks += numTasksPending + localities.foreach { case (hostname, count) => + val updatedCount = localityToCount.getOrElse(hostname, 0) + count + localityToCount(hostname) = updatedCount + } + } + + allocationManager.localityAwareTasks = localityAwareTasks + allocationManager.hostToLocalTaskCount = localityToCount.toMap + } } /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6a6b94a271cfc..ac6ac6c216767 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1382,16 +1382,29 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Express a preference to the cluster manager for a given total number of executors. - * This can result in canceling pending requests or filing additional requests. - * This is currently only supported in YARN mode. Return whether the request is received. - */ - private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = { + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. + * @param numExecutors The total number of executors we'd like to have. 
The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. + */ + private[spark] override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] + ): Boolean = { assert(supportDynamicAllocation, "Requesting executors is currently only supported in YARN and Mesos modes") schedulerBackend match { case b: CoarseGrainedSchedulerBackend => - b.requestTotalExecutors(numExecutors) + b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) case _ => logWarning("Requesting executors is only supported in coarse-grained mode") false diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b6a833bbb0833..cdf6078421123 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -790,8 +790,28 @@ class DAGScheduler( // serializable. If tasks are not serializable, a SparkListenerStageCompleted event // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. - stage.makeNewStageAttempt(partitionsToCompute.size) outputCommitCoordinator.stageStart(stage.id) + val taskIdToLocations = try { + stage match { + case s: ShuffleMapStage => + partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap + case s: ResultStage => + val job = s.resultOfJob.get + partitionsToCompute.map { id => + val p = job.partitions(id) + (id, getPreferredLocs(stage.rdd, p)) + }.toMap + } + } catch { + case NonFatal(e) => + stage.makeNewStageAttempt(partitionsToCompute.size) + listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + + stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
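To make the new data flow concrete, here is a minimal, self-contained sketch (not part of the patch; the object name and host names are illustrative only) of how per-task preferred locations like those computed above aggregate into the (localityAwareTasks, hostToLocalTaskCount) hints that requestTotalExecutors receives:

  object PlacementHintsSketch {
    // Turn one preferred-location list per task into (number of locality-aware tasks,
    // host -> number of tasks that would like to run on that host).
    def aggregate(taskLocalityPreferences: Seq[Seq[String]]): (Int, Map[String, Int]) = {
      // Tasks with an empty preference list carry no locality information.
      val withPrefs = taskLocalityPreferences.filter(_.nonEmpty)
      val hostToLocalTaskCount =
        withPrefs.flatten.groupBy(identity).map { case (host, hits) => host -> hits.size }
      (withPrefs.size, hostToLocalTaskCount)
    }

    def main(args: Array[String]): Unit = {
      // Mirrors the first stage in the ExecutorAllocationManagerSuite test added below:
      // prints (3, Map(host1 -> 2, host2 -> 3, host3 -> 2, host4 -> 2)); map order may vary.
      println(aggregate(Seq(
        Seq("host1", "host2", "host3"),
        Seq("host1", "host2", "host4"),
        Seq("host2", "host3", "host4"),
        Seq.empty,
        Seq.empty)))
    }
  }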
@@ -830,7 +850,7 @@ class DAGScheduler( stage match { case stage: ShuffleMapStage => partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) + val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs) } @@ -840,7 +860,7 @@ class DAGScheduler( partitionsToCompute.map { id => val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) + val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b86724de2cb73..40a333a3e06b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -77,8 +77,11 @@ private[spark] abstract class Stage( private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ - def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = { - _latestInfo = StageInfo.fromStage(this, nextAttemptId, Some(numPartitionsToCompute)) + def makeNewStageAttempt( + numPartitionsToCompute: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { + _latestInfo = StageInfo.fromStage( + this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 5d2abbc67e9d9..24796c14300b1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -34,7 +34,8 @@ class StageInfo( val numTasks: Int, val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], - val details: String) { + val details: String, + private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. */ @@ -70,7 +71,12 @@ private[spark] object StageInfo { * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a * sequence of narrow dependencies should also be associated with this Stage. 
*/ - def fromStage(stage: Stage, attemptId: Int, numTasks: Option[Int] = None): StageInfo = { + def fromStage( + stage: Stage, + attemptId: Int, + numTasks: Option[Int] = None, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos new StageInfo( @@ -80,6 +86,7 @@ private[spark] object StageInfo { numTasks.getOrElse(stage.numTasks), rddInfos, stage.parents.map(_.id), - stage.details) + stage.details, + taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 4be1eda2e9291..06f5438433b6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -86,7 +86,11 @@ private[spark] object CoarseGrainedClusterMessages { // Request executors by specifying the new total number of executors desired // This includes executors already pending or running - case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + case class RequestExecutors( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]) + extends CoarseGrainedClusterMessage case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c65b3e517773e..660702f6e6fd0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -66,6 +66,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] + // A map to store hostname with its possible task number running on it + protected var hostToLocalTaskCount: Map[String, Int] = Map.empty + + // The number of pending tasks which is locality required + protected var localityAwareTasks = 0 + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -339,6 +345,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors // Account for executors pending to be added or removed val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size @@ -346,16 +353,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } /** - * Express a preference to the cluster manager for a given total number of executors. This can - * result in canceling pending requests or filing additional requests. - * @return whether the request is acknowledged. + * Update the cluster manager on our scheduling needs. Three bits of information are included + * to help it make decisions. 
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager + * shouldn't kill any running executor to reach this number, but, + * if all existing executors were to die, this is the number of executors + * we'd want to be allocated. + * @param localityAwareTasks The number of tasks in all active stages that have locality + * preferences. This includes running, pending, and completed tasks. + * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages + * that would like to run on that host. + * This includes running, pending, and completed tasks. + * @return whether the request is acknowledged by the cluster manager. */ - final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized { + final override def requestTotalExecutors( + numExecutors: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int] + ): Boolean = synchronized { if (numExecutors < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + s"$numExecutors from the cluster manager. Please specify a positive number!") } + + this.localityAwareTasks = localityAwareTasks + this.hostToLocalTaskCount = hostToLocalTaskCount + numPendingExecutors = math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) doRequestTotalExecutors(numExecutors) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index bc67abb5df446..074282d1be37d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -53,7 +53,8 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running.
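* The locality hints recorded by the scheduler backend (localityAwareTasks and * hostToLocalTaskCount) are forwarded to the application master together with the requested total.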
*/ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) + yarnSchedulerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } /** diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 803e1831bb269..34caca892891c 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -751,6 +751,42 @@ class ExecutorAllocationManagerSuite assert(numExecutorsTarget(manager) === 2) } + test("get pending task number and related locality preference") { + sc = createSparkContext(2, 5, 3) + val manager = sc.executorAllocationManager.get + + val localityPreferences1 = Seq( + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host3")), + Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host4")), + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host4")), + Seq.empty, + Seq.empty + ) + val stageInfo1 = createStageInfo(1, 5, localityPreferences1) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1)) + + assert(localityAwareTasks(manager) === 3) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 3, "host3" -> 2, "host4" -> 2)) + + val localityPreferences2 = Seq( + Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host5")), + Seq(TaskLocation("host3"), TaskLocation("host4"), TaskLocation("host5")), + Seq.empty + ) + val stageInfo2 = createStageInfo(2, 3, localityPreferences2) + sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2)) + + assert(localityAwareTasks(manager) === 5) + assert(hostToLocalTaskCount(manager) === + Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) + + sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1)) + assert(localityAwareTasks(manager) === 2) + assert(hostToLocalTaskCount(manager) === + Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) + } + private def createSparkContext( minExecutors: Int = 1, maxExecutors: Int = 5, @@ -784,8 +820,13 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val sustainedSchedulerBacklogTimeout = 2L private val executorIdleTimeout = 3L - private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = { - new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details") + private def createStageInfo( + stageId: Int, + numTasks: Int, + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + ): StageInfo = { + new StageInfo( + stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { @@ -815,6 +856,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty) private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle) private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy) + private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks) + private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount) private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { manager invokePrivate 
_numExecutorsToAdd() @@ -885,4 +928,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = { manager invokePrivate _onExecutorBusy(id) } + + private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _localityAwareTasks() + } + + private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { + manager invokePrivate _hostToLocalTaskCount() + } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 5a2670e4d1cf0..139b8dc25f4b4 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -182,7 +182,7 @@ class HeartbeatReceiverSuite // Adjust the target number of executors on the cluster manager side assert(fakeClusterManager.getTargetNumExecutors === 0) - sc.requestTotalExecutors(2) + sc.requestTotalExecutors(2, 0, Map.empty) assert(fakeClusterManager.getTargetNumExecutors === 2) assert(fakeClusterManager.getExecutorIdsToKill.isEmpty) @@ -241,7 +241,8 @@ private class FakeSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - clusterManagerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal)) + clusterManagerEndpoint.askWithRetry[Boolean]( + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) } protected override def doKillExecutors(executorIds: Seq[String]): Boolean = { @@ -260,7 +261,7 @@ private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoin def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal) => + case RequestExecutors(requestedTotal, _, _) => targetNumExecutors = requestedTotal context.reply(true) case KillExecutors(executorIds) => diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 83dafa4a125d2..44acc7374d024 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -555,11 +555,12 @@ private[spark] class ApplicationMaster( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal) => + case RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount) => Option(allocator) match { case Some(a) => allocatorLock.synchronized { - if (a.requestTotalExecutors(requestedTotal)) { + if (a.requestTotalExecutorsWithPreferredLocalities(requestedTotal, + localityAwareTasks, hostToLocalTaskCount)) { allocatorLock.notifyAll() } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala new file mode 100644 index 0000000000000..081780204e424 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/LocalityPreferredContainerPlacementStrategy.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import scala.collection.mutable.{ArrayBuffer, HashMap, Set} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records.{ContainerId, Resource} +import org.apache.hadoop.yarn.util.RackResolver + +import org.apache.spark.SparkConf + +private[yarn] case class ContainerLocalityPreferences(nodes: Array[String], racks: Array[String]) + +/** + * This strategy calculates the optimal locality preferences of YARN containers by considering + * the node ratio of pending tasks, the number of required cores/containers and the locality of + * currently existing containers. The target of this algorithm is to maximize the number of tasks + * that would run locally. + * + * Consider a situation in which we have 20 tasks that require (host1, host2, host3) + * and 10 tasks that require (host1, host2, host4); each container has 2 cores + * and each task uses 1 cpu, so the required container number is 15, + * and the host ratio is (host1: 30, host2: 30, host3: 20, host4: 10). + * + * 1. If the requested container number (18) is more than the required container number (15): + * + * requests for 5 containers with nodes: (host1, host2, host3, host4) + * requests for 5 containers with nodes: (host1, host2, host3) + * requests for 5 containers with nodes: (host1, host2) + * requests for 3 containers with no locality preferences. + * + * The placement ratio is 3 : 3 : 2 : 1, and the additional containers are requested with no + * locality preferences. + * + * 2. If the requested container number (10) is less than or equal to the required container + * number (15): + * + * requests for 4 containers with nodes: (host1, host2, host3, host4) + * requests for 3 containers with nodes: (host1, host2, host3) + * requests for 3 containers with nodes: (host1, host2) + * + * The placement ratio is 10 : 10 : 7 : 4, close to the expected ratio (3 : 3 : 2 : 1) + * + * 3. If containers exist but none of them can match the requested localities, + * follow methods 1 and 2 above. + * + * 4. If containers exist and some of them can match the requested localities. + * For example if we have 1 container on each node (host1: 1, host2: 1, host3: 1, host4: 1), + * and the expected containers on each node would be (host1: 5, host2: 5, host3: 4, host4: 2), + * so the newly requested containers on each node would be updated to (host1: 4, host2: 4, + * host3: 3, host4: 1), 12 containers in total. + * + * 4.1 If the requested container number (18) is more than the newly required container number + * (12), follow method 1 with the updated ratio 4 : 4 : 3 : 1. + * + * 4.2 If the requested container number (10) is less than or equal to the newly required + * container number (12), follow method 2 with the updated ratio 4 : 4 : 3 : 1. + * + * 5. If containers exist and the existing localities can fully cover the requested localities.
+ * For example, if we have 5 containers on each node (host1: 5, host2: 5, host3: 5, host4: 5), + * the existing containers already cover the requested localities, so this algorithm will + * allocate all the requested containers with no locality preferences. + */ +private[yarn] class LocalityPreferredContainerPlacementStrategy( + val sparkConf: SparkConf, + val yarnConf: Configuration, + val resource: Resource) { + + // Number of CPUs per task + private val CPUS_PER_TASK = sparkConf.getInt("spark.task.cpus", 1) + + /** + * Calculate each container's node locality and rack locality. + * @param numContainer number of containers to calculate + * @param numLocalityAwareTasks number of tasks that require locality + * @param hostToLocalTaskCount a map from preferred hostname to the number of tasks that + * prefer it, used as hints for container allocation + * @return node localities and rack localities; each locality is an array of strings, + * and the number of localities is the same as the number of containers + */ + def localityOfRequestedContainers( + numContainer: Int, + numLocalityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Array[ContainerLocalityPreferences] = { + val updatedHostToContainerCount = expectedHostToContainerCount( + numLocalityAwareTasks, hostToLocalTaskCount, allocatedHostToContainersMap) + val updatedLocalityAwareContainerNum = updatedHostToContainerCount.values.sum + + // The number of containers to allocate, divided into two groups, one with preferred locality, + // and the other without locality preference. + val requiredLocalityFreeContainerNum = + math.max(0, numContainer - updatedLocalityAwareContainerNum) + val requiredLocalityAwareContainerNum = numContainer - requiredLocalityFreeContainerNum + + val containerLocalityPreferences = ArrayBuffer[ContainerLocalityPreferences]() + if (requiredLocalityFreeContainerNum > 0) { + for (i <- 0 until requiredLocalityFreeContainerNum) { + containerLocalityPreferences += ContainerLocalityPreferences( + null.asInstanceOf[Array[String]], null.asInstanceOf[Array[String]]) + } + } + + if (requiredLocalityAwareContainerNum > 0) { + val largestRatio = updatedHostToContainerCount.values.max + // Scale each host's ratio to the number of locality-aware containers requested; + // this is used below when calculating the locality-preferred hosts. + var preferredLocalityRatio = updatedHostToContainerCount.mapValues { ratio => + val adjustedRatio = ratio.toDouble * requiredLocalityAwareContainerNum / largestRatio + adjustedRatio.ceil.toInt + } + + for (i <- 0 until requiredLocalityAwareContainerNum) { + // Only keep the hosts whose ratio is larger than 0, which means the current host can + // still be assigned to a new container request. + val hosts = preferredLocalityRatio.filter(_._2 > 0).keys.toArray + val racks = hosts.map { h => + RackResolver.resolve(yarnConf, h).getNetworkLocation + }.toSet + containerLocalityPreferences += ContainerLocalityPreferences(hosts, racks.toArray) + + // Decrement each host's ratio by 1 every time the host is used. Once a ratio reaches 0, + // its required share is satisfied and that host will not be allocated again. + preferredLocalityRatio = preferredLocalityRatio.mapValues(_ - 1) + } + } + + containerLocalityPreferences.toArray + } + + /** + * Calculate the number of executors needed to satisfy the given number of pending tasks.
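+ * For example, with 30 pending tasks, 1 cpu per task and 2 cores per container, + * (30 * 1 + 2 - 1) / 2 = 15 containers are needed; the (+ coresPerExecutor - 1) term makes + * the integer division round up.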
+ */ + private def numExecutorsPending(numTasksPending: Int): Int = { + val coresPerExecutor = resource.getVirtualCores + (numTasksPending * CPUS_PER_TASK + coresPerExecutor - 1) / coresPerExecutor + } + + /** + * Calculate the expected number of containers per host, taking already-allocated containers + * into account. + * @param localityAwareTasks number of locality-aware tasks + * @param hostToLocalTaskCount a map from preferred hostname to the number of tasks that + * prefer it, used as hints for container allocation + * @return a map with hostname as key and required number of containers on this host as value + */ + private def expectedHostToContainerCount( + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int], + allocatedHostToContainersMap: HashMap[String, Set[ContainerId]] + ): Map[String, Int] = { + val totalLocalTaskNum = hostToLocalTaskCount.values.sum + hostToLocalTaskCount.map { case (host, count) => + val expectedCount = + count.toDouble * numExecutorsPending(localityAwareTasks) / totalLocalTaskNum + val existedCount = allocatedHostToContainersMap.get(host) + .map(_.size) + .getOrElse(0) + + // If the existing containers cannot fully satisfy the expected number of containers, + // the required container number is the expected count minus the existing count. Otherwise + // the required container number is 0. + (host, math.max(0, (expectedCount - existedCount).ceil.toInt)) + } + } +} diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 940873fbd046c..6c103394af098 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -96,7 +96,7 @@ private[yarn] class YarnAllocator( // Number of cores per executor. protected val executorCores = args.executorCores // Resource capability requested for each executor - private val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) + private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) private val launcherPool = new ThreadPoolExecutor( // max pool size of Integer.MAX_VALUE is ignored because we use an unbounded queue @@ -127,6 +127,16 @@ private[yarn] class YarnAllocator( } } + // A map from preferred hostname to the number of tasks that prefer to run on it. + private var hostToLocalTaskCounts: Map[String, Int] = Map.empty + + // Number of tasks that have locality preferences in active stages + private var numLocalityAwareTasks: Int = 0 + + // A container placement strategy based on pending tasks' locality preference + private[yarn] val containerPlacementStrategy = + new LocalityPreferredContainerPlacementStrategy(sparkConf, conf, resource) + def getNumExecutorsRunning: Int = numExecutorsRunning def getNumExecutorsFailed: Int = numExecutorsFailed @@ -146,10 +156,19 @@ private[yarn] class YarnAllocator( * Request as many executors from the ResourceManager as needed to reach the desired total. If * the requested total is smaller than the current number of running executors, no executors will * be killed. - * + * @param requestedTotal total number of containers requested + * @param localityAwareTasks number of locality-aware tasks to be used as container placement hint + * @param hostToLocalTaskCount a map of preferred hostname to possible task counts to be used as + * container placement hint. * @return Whether the new requested total is different from the old value.
*/ - def requestTotalExecutors(requestedTotal: Int): Boolean = synchronized { + def requestTotalExecutorsWithPreferredLocalities( + requestedTotal: Int, + localityAwareTasks: Int, + hostToLocalTaskCount: Map[String, Int]): Boolean = synchronized { + this.numLocalityAwareTasks = localityAwareTasks + this.hostToLocalTaskCounts = hostToLocalTaskCount + if (requestedTotal != targetNumExecutors) { logInfo(s"Driver requested a total number of $requestedTotal executor(s).") targetNumExecutors = requestedTotal @@ -221,12 +240,20 @@ private[yarn] class YarnAllocator( val numPendingAllocate = getNumPendingAllocate val missing = targetNumExecutors - numPendingAllocate - numExecutorsRunning + // TODO. Consider locality preferences of pending container requests. + // Since the last time we made container requests, stages have completed and been submitted, + // and that the localities at which we requested our pending executors + // no longer apply to our current needs. We should consider to remove all outstanding + // container requests and add requests anew each time to avoid this. if (missing > 0) { logInfo(s"Will request $missing executor containers, each with ${resource.getVirtualCores} " + s"cores and ${resource.getMemory} MB memory including $memoryOverhead MB overhead") - for (i <- 0 until missing) { - val request = createContainerRequest(resource) + val containerLocalityPreferences = containerPlacementStrategy.localityOfRequestedContainers( + missing, numLocalityAwareTasks, hostToLocalTaskCounts, allocatedHostToContainersMap) + + for (locality <- containerLocalityPreferences) { + val request = createContainerRequest(resource, locality.nodes, locality.racks) amClient.addContainerRequest(request) val nodes = request.getNodes val hostStr = if (nodes == null || nodes.isEmpty) "Any" else nodes.last @@ -249,11 +276,14 @@ private[yarn] class YarnAllocator( * Creates a container request, handling the reflection required to use YARN features that were * added in recent versions. */ - private def createContainerRequest(resource: Resource): ContainerRequest = { + protected def createContainerRequest( + resource: Resource, + nodes: Array[String], + racks: Array[String]): ContainerRequest = { nodeLabelConstructor.map { constructor => - constructor.newInstance(resource, null, null, RM_REQUEST_PRIORITY, true: java.lang.Boolean, + constructor.newInstance(resource, nodes, racks, RM_REQUEST_PRIORITY, true: java.lang.Boolean, labelExpression.orNull) - }.getOrElse(new ContainerRequest(resource, null, null, RM_REQUEST_PRIORITY)) + }.getOrElse(new ContainerRequest(resource, nodes, racks, RM_REQUEST_PRIORITY)) } /** @@ -437,7 +467,6 @@ private[yarn] class YarnAllocator( releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) } - } private object YarnAllocator { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala new file mode 100644 index 0000000000000..b7fe4ccc67a38 --- /dev/null +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ContainerPlacementStrategySuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.scalatest.{BeforeAndAfterEach, Matchers} + +import org.apache.spark.SparkFunSuite + +class ContainerPlacementStrategySuite extends SparkFunSuite with Matchers with BeforeAndAfterEach { + + private val yarnAllocatorSuite = new YarnAllocatorSuite + import yarnAllocatorSuite._ + + override def beforeEach() { + yarnAllocatorSuite.beforeEach() + } + + override def afterEach() { + yarnAllocatorSuite.afterEach() + } + + test("allocate locality preferred containers with enough resource and no matched existed " + + "containers") { + // 1. All the locations of current containers cannot satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. + + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host3" -> 15, "host4" -> 15, "host5" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array( + Array("host3", "host4", "host5"), + Array("host3", "host4", "host5"), + Array("host3", "host4"))) + } + + test("allocate locality preferred containers with enough resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number can fully satisfy the pending tasks. + + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === + Array(null, Array("host2", "host3"), Array("host2", "host3"))) + } + + test("allocate locality preferred containers with limited resource and partially matched " + + "containers") { + // 1. Parts of current containers' locations can satisfy the new requirements + // 2. Current requested container number cannot fully satisfy the pending tasks. 
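+ // (host1 already holds enough allocated containers to cover its share, so the single + // request below should only prefer host2 and host3)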
+ + val handler = createAllocator(3) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(Array("host2", "host3"))) + } + + test("allocate locality preferred containers with fully matched containers") { + // Current containers' locations can fully satisfy the new requirements + + val handler = createAllocator(5) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array( + createContainer("host1"), + createContainer("host1"), + createContainer("host2"), + createContainer("host2"), + createContainer("host3") + )) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 3, 15, Map("host1" -> 15, "host2" -> 15, "host3" -> 10), handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null, null, null)) + } + + test("allocate containers with no locality preference") { + // Request new container without locality preference + + val handler = createAllocator(2) + handler.updateResourceRequests() + handler.handleAllocatedContainers(Array(createContainer("host1"), createContainer("host2"))) + + val localities = handler.containerPlacementStrategy.localityOfRequestedContainers( + 1, 0, Map.empty, handler.allocatedHostToContainersMap) + + assert(localities.map(_.nodes) === Array(null)) + } +} diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 7509000771d94..37a789fcd375b 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.net.DNSToSwitchMapping import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest +import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf @@ -32,8 +33,6 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ import org.apache.spark.scheduler.SplitInfo -import org.scalatest.{BeforeAndAfterEach, Matchers} - class MockResolver extends DNSToSwitchMapping { override def resolve(names: JList[String]): JList[String] = { @@ -171,7 +170,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) - handler.requestTotalExecutors(3) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (3) @@ -182,7 +181,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.allocatedContainerToHostMap.get(container.getId).get should be ("host1") handler.allocatedHostToContainersMap.get("host1").get should contain (container.getId) - handler.requestTotalExecutors(2) + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (1) } @@ -193,7 +192,7 @@ class 
YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (0) handler.getNumPendingAllocate should be (4) - handler.requestTotalExecutors(3) + handler.requestTotalExecutorsWithPreferredLocalities(3, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (3) @@ -203,7 +202,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumExecutorsRunning should be (2) - handler.requestTotalExecutors(1) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.updateResourceRequests() handler.getNumPendingAllocate should be (0) handler.getNumExecutorsRunning should be (2) @@ -219,7 +218,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter val container2 = createContainer("host2") handler.handleAllocatedContainers(Array(container1, container2)) - handler.requestTotalExecutors(1) + handler.requestTotalExecutorsWithPreferredLocalities(1, 0, Map.empty) handler.executorIdToContainer.keys.foreach { id => handler.killExecutor(id ) } val statuses = Seq(container1, container2).map { c => @@ -241,5 +240,4 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter assert(vmemMsg.contains("5.8 GB of 4.2 GB virtual memory used.")) assert(pmemMsg.contains("2.1 MB of 2 GB physical memory used.")) } - } From dafe8d857dff4c61981476282cbfe11f5c008078 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 27 Jul 2015 15:49:42 -0700 Subject: [PATCH 096/219] [SPARK-9385] [PYSPARK] Enable PEP8 but disable installing pylint. Instead of disabling all python style check, we should enable PEP8. So, this PR just comments out the part installing pylint. Author: Yin Huai Closes #7704 from yhuai/SPARK-9385 and squashes the following commits: 0056359 [Yin Huai] Enable PEP8 but disable installing pylint. --- dev/lint-python | 30 +++++++++++++++--------------- dev/run-tests.py | 5 ++--- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/dev/lint-python b/dev/lint-python index 53bccc1fab535..575dbb0ae321b 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -58,21 +58,21 @@ export "PYTHONPATH=$SPARK_ROOT_DIR/dev/pylint" export "PYLINT_HOME=$PYTHONPATH" export "PATH=$PYTHONPATH:$PATH" -if [ ! -d "$PYLINT_HOME" ]; then - mkdir "$PYLINT_HOME" - # Redirect the annoying pylint installation output. - easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" - easy_install_status="$?" - - if [ "$easy_install_status" -ne 0 ]; then - echo "Unable to install pylint locally in \"$PYTHONPATH\"." - cat "$PYLINT_INSTALL_INFO" - exit "$easy_install_status" - fi - - rm "$PYLINT_INSTALL_INFO" - -fi +# if [ ! -d "$PYLINT_HOME" ]; then +# mkdir "$PYLINT_HOME" +# # Redirect the annoying pylint installation output. +# easy_install -d "$PYLINT_HOME" pylint==1.4.4 &>> "$PYLINT_INSTALL_INFO" +# easy_install_status="$?" +# +# if [ "$easy_install_status" -ne 0 ]; then +# echo "Unable to install pylint locally in \"$PYTHONPATH\"." 
+# cat "$PYLINT_INSTALL_INFO" +# exit "$easy_install_status" +# fi +# +# rm "$PYLINT_INSTALL_INFO" +# +# fi # There is no need to write this output to a file #+ first, but we do so so that the check status can diff --git a/dev/run-tests.py b/dev/run-tests.py index d1cb66860b3f8..1f0d218514f92 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -198,9 +198,8 @@ def run_scala_style_checks(): def run_python_style_checks(): - # set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") - # run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) - pass + set_title_and_block("Running Python style checks", "BLOCK_PYTHON_STYLE") + run_cmd([os.path.join(SPARK_HOME, "dev", "lint-python")]) def build_spark_documentation(): From 8ddfa52c208bf329c2b2c8909c6be04301e36083 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 27 Jul 2015 17:17:49 -0700 Subject: [PATCH 097/219] [SPARK-9230] [ML] Support StringType features in RFormula This adds StringType feature support via OneHotEncoder. As part of this task it was necessary to change RFormula to an Estimator, so that factor levels could be determined from the training dataset. Not sure if I am using uids correctly here, would be good to get reviewer help on that. cc mengxr Umbrella design doc: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit# Author: Eric Liang Closes #7574 from ericl/string-features and squashes the following commits: f99131a [Eric Liang] comments 0bf3c26 [Eric Liang] update docs c302a2c [Eric Liang] fix tests 9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features e713da3 [Eric Liang] comments 4d79193 [Eric Liang] revert to seq + distinct 169a085 [Eric Liang] tweak functional test a230a47 [Eric Liang] Merge branch 'master' into string-features 72bd6f3 [Eric Liang] fix merge d841cec [Eric Liang] Merge branch 'master' into string-features 5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015 b01c7c5 [Eric Liang] add test 8a637db [Eric Liang] encoder wip a1d03f4 [Eric Liang] refactor into estimator --- R/pkg/inst/tests/test_mllib.R | 6 +- .../apache/spark/ml/feature/RFormula.scala | 133 ++++++++++++++---- .../ml/feature/RFormulaParserSuite.scala | 1 + .../spark/ml/feature/RFormulaSuite.scala | 64 +++++---- 4 files changed, 142 insertions(+), 62 deletions(-) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index a492763344ae6..29152a11688a2 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -35,8 +35,8 @@ test_that("glm and predict", { test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Sepal_Length, data = training) + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index f7b46efa10e90..0a95b1ee8de6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,17 +17,34 @@ package org.apache.spark.ml.feature +import 
scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.Transformer +import org.apache.spark.ml.{Estimator, Model, Transformer, Pipeline, PipelineModel, PipelineStage} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.VectorUDT import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +/** + * Base trait for [[RFormula]] and [[RFormulaModel]]. + */ +private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { + /** @group getParam */ + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group getParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + protected def hasLabelCol(schema: StructType): Boolean = { + schema.map(_.name).contains($(labelCol)) + } +} + /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently @@ -35,8 +52,7 @@ import org.apache.spark.sql.types._ * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental -class RFormula(override val uid: String) - extends Transformer with HasFeaturesCol with HasLabelCol { +class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { def this() = this(Identifiable.randomUID("rFormula")) @@ -62,19 +78,74 @@ class RFormula(override val uid: String) /** @group getParam */ def getFormula: String = $(formula) - /** @group getParam */ - def setFeaturesCol(value: String): this.type = set(featuresCol, value) + override def fit(dataset: DataFrame): RFormulaModel = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + // StringType terms and terms representing interactions need to be encoded before assembly. 
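+    // For example (hypothetical data): a string column with levels {foo, bar, baz} is first
+    // mapped to category indices by StringIndexer, then expanded by OneHotEncoder into a 0/1
+    // vector with one category dropped (three levels become two dimensions, as the expected
+    // vectors in RFormulaSuite show), and finally concatenated by VectorAssembler.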
+ // TODO(ekl) add support for feature interactions + var encoderStages = ArrayBuffer[PipelineStage]() + var tempColumns = ArrayBuffer[String]() + val encodedTerms = parsedFormula.get.terms.map { term => + dataset.schema(term) match { + case column if column.dataType == StringType => + val indexCol = term + "_idx_" + uid + val encodedCol = term + "_onehot_" + uid + encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) + encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) + tempColumns += indexCol + tempColumns += encodedCol + encodedCol + case _ => + term + } + } + encoderStages += new VectorAssembler(uid) + .setInputCols(encodedTerms.toArray) + .setOutputCol($(featuresCol)) + encoderStages += new ColumnPruner(tempColumns.toSet) + val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) + copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this)) + } - /** @group getParam */ - def setLabelCol(value: String): this.type = set(labelCol, value) + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + if (hasLabelCol(schema)) { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true) :+ + StructField($(labelCol), DoubleType, true)) + } + } + + override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + + override def toString: String = s"RFormula(${get(formula)})" +} + +/** + * :: Experimental :: + * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. + * @param parsedFormula a pre-parsed R formula. + * @param pipelineModel the fitted feature model, including factor to index mappings. 
+ */ +@Experimental +class RFormulaModel private[feature]( + override val uid: String, + parsedFormula: ParsedRFormula, + pipelineModel: PipelineModel) + extends Model[RFormulaModel] with RFormulaBase { + + override def transform(dataset: DataFrame): DataFrame = { + checkCanTransform(dataset.schema) + transformLabel(pipelineModel.transform(dataset)) + } override def transformSchema(schema: StructType): StructType = { checkCanTransform(schema) - val withFeatures = transformFeatures.transformSchema(schema) + val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else if (schema.exists(_.name == parsedFormula.get.label)) { - val nullable = schema(parsedFormula.get.label).dataType match { + } else if (schema.exists(_.name == parsedFormula.label)) { + val nullable = schema(parsedFormula.label).dataType match { case _: NumericType | BooleanType => false case _ => true } @@ -86,24 +157,19 @@ class RFormula(override val uid: String) } } - override def transform(dataset: DataFrame): DataFrame = { - checkCanTransform(dataset.schema) - transformLabel(transformFeatures.transform(dataset)) - } - - override def copy(extra: ParamMap): RFormula = defaultCopy(extra) + override def copy(extra: ParamMap): RFormulaModel = copyValues( + new RFormulaModel(uid, parsedFormula, pipelineModel)) - override def toString: String = s"RFormula(${get(formula)})" + override def toString: String = s"RFormulaModel(${parsedFormula})" private def transformLabel(dataset: DataFrame): DataFrame = { - val labelName = parsedFormula.get.label + val labelName = parsedFormula.label if (hasLabelCol(dataset.schema)) { dataset } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) - // TODO(ekl) add support for string-type labels case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } @@ -114,25 +180,32 @@ class RFormula(override val uid: String) } } - private def transformFeatures: Transformer = { - // TODO(ekl) add support for non-numeric features and feature interactions - new VectorAssembler(uid) - .setInputCols(parsedFormula.get.terms.toArray) - .setOutputCol($(featuresCol)) - } - private def checkCanTransform(schema: StructType) { - require(parsedFormula.isDefined, "Must call setFormula() first.") val columnNames = schema.map(_.name) require(!columnNames.contains($(featuresCol)), "Features column already exists.") require( !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType, "Label column already exists and is not of type DoubleType.") } +} - private def hasLabelCol(schema: StructType): Boolean = { - schema.map(_.name).contains($(labelCol)) +/** + * Utility transformer for removing temporary columns from a DataFrame. 
+ * TODO(ekl) make this a public transformer + */ +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { + override val uid = Identifiable.randomUID("columnPruner") + + override def transform(dataset: DataFrame): DataFrame = { + val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) + dataset.select(columnsToKeep.map(dataset.col) : _*) } + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) + } + + override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } /** @@ -149,7 +222,7 @@ private[ml] object RFormulaParser extends RegexParsers { def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } def formula: Parser[ParsedRFormula] = - (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) } def parse(value: String): ParsedRFormula = parseAll(formula, value) match { case Success(result, _) => result diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index c8d065f37a605..c4b45aee06384 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -28,6 +28,7 @@ class RFormulaParserSuite extends SparkFunSuite { test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) + checkParse("y ~ x + x", "y", Seq("x")) checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 79c4ccf02d4e0..8148c553e9051 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -31,72 +31,78 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val formula = new RFormula().setFormula("id ~ v1 + v2") val original = sqlContext.createDataFrame( Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2") - val result = formula.transform(original) - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) val expected = sqlContext.createDataFrame( Seq( - (0, 1.0, 3.0, Vectors.dense(Array(1.0, 3.0)), 0.0), - (2, 2.0, 5.0, Vectors.dense(Array(2.0, 5.0)), 2.0)) + (0, 1.0, 3.0, Vectors.dense(1.0, 3.0), 0.0), + (2, 2.0, 5.0, Vectors.dense(2.0, 5.0), 2.0)) ).toDF("id", "v1", "v2", "features", "label") // TODO(ekl) make schema comparisons ignore metadata, to avoid .toString assert(result.schema.toString == resultSchema.toString) assert(resultSchema == expected.schema) - assert(result.collect().toSeq == expected.collect().toSeq) + assert(result.collect() === expected.collect()) } test("features column already exists") { val formula = new RFormula().setFormula("y ~ x").setFeaturesCol("x") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") intercept[IllegalArgumentException] { - formula.transformSchema(original.schema) + formula.fit(original) } intercept[IllegalArgumentException] { - formula.transform(original) + formula.fit(original) } } 
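To make the new Estimator workflow concrete, here is a minimal usage sketch (column names and
data are hypothetical; `sqlContext` is assumed to be in scope, as in the suite above):

    import org.apache.spark.ml.feature.RFormula

    val df = sqlContext.createDataFrame(
      Seq((1, "foo", 4.0), (2, "bar", 5.0), (3, "baz", 5.0))).toDF("id", "a", "b")
    // fit() scans df to determine the factor levels of the string column "a";
    // transform() then appends the assembled "features" and "label" columns.
    val model = new RFormula().setFormula("id ~ a + b").fit(df)
    model.transform(df).show()
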
test("label column already exists") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "y") - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) - assert(resultSchema.toString == formula.transform(original).schema.toString) + assert(resultSchema.toString == model.transform(original).schema.toString) } test("label column already exists but is not double type") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("y") val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y") + val model = formula.fit(original) intercept[IllegalArgumentException] { - formula.transformSchema(original.schema) + model.transformSchema(original.schema) } intercept[IllegalArgumentException] { - formula.transform(original) + model.transform(original) } } test("allow missing label column for test datasets") { val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") - val resultSchema = formula.transformSchema(original.schema) + val model = formula.fit(original) + val resultSchema = model.transformSchema(original.schema) assert(resultSchema.length == 3) assert(!resultSchema.exists(_.name == "label")) - assert(resultSchema.toString == formula.transform(original).schema.toString) + assert(resultSchema.toString == model.transform(original).schema.toString) } -// TODO(ekl) enable after we implement string label support -// test("transform string label") { -// val formula = new RFormula().setFormula("name ~ id") -// val original = sqlContext.createDataFrame( -// Seq((1, "foo"), (2, "bar"), (3, "bar"))).toDF("id", "name") -// val result = formula.transform(original) -// val resultSchema = formula.transformSchema(original.schema) -// val expected = sqlContext.createDataFrame( -// Seq( -// (1, "foo", Vectors.dense(Array(1.0)), 1.0), -// (2, "bar", Vectors.dense(Array(2.0)), 0.0), -// (3, "bar", Vectors.dense(Array(3.0)), 0.0)) -// ).toDF("id", "name", "features", "label") -// assert(result.schema.toString == resultSchema.toString) -// assert(result.collect().toSeq == expected.collect().toSeq) -// } + test("encodes string terms") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val resultSchema = model.transformSchema(original.schema) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.schema.toString == resultSchema.toString) + assert(result.collect() === expected.collect()) + } } From ce89ff477aea6def68265ed218f6105680755c9a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 27 Jul 2015 17:32:34 -0700 Subject: [PATCH 098/219] [SPARK-9386] [SQL] Feature flag for metastore partition pruning Since we have been seeing a lot of failures related to this new feature, lets put it behind a flag and turn it off by default. 
Author: Michael Armbrust Closes #7703 from marmbrus/optionalMetastorePruning and squashes the following commits: 6ad128c [Michael Armbrust] style 8447835 [Michael Armbrust] [SPARK-9386][SQL] Feature flag for metastore partition pruning fd37b87 [Michael Armbrust] add config flag --- .../main/scala/org/apache/spark/sql/SQLConf.scala | 7 +++++++ .../apache/spark/sql/hive/HiveMetastoreCatalog.scala | 12 +++++++++++- .../spark/sql/hive/client/ClientInterface.scala | 10 ++++------ 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 9b2dbd7442f5c..40eba33f595ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -301,6 +301,11 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", + defaultValue = Some(false), + doc = "When true, some predicates will be pushed down into the Hive metastore so that " + + "unmatching partitions can be eliminated earlier.") + val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), doc = "") @@ -456,6 +461,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) + private[spark] def metastorePartitionPruning: Boolean = getConf(HIVE_METASTORE_PARTITION_PRUNING) + private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT) private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9c707a7a2eca1..3180c05445c9f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -678,8 +678,18 @@ private[hive] case class MetastoreRelation } ) + // When metastore partition pruning is turned off, we cache the list of all partitions to + // mimic the behavior of Spark < 1.5 + lazy val allPartitions = table.getAllPartitions + def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = { - table.getPartitions(predicates).map { p => + val rawPartitions = if (sqlContext.conf.metastorePartitionPruning) { + table.getPartitions(predicates) + } else { + allPartitions + } + + rawPartitions.map { p => val tPartition = new org.apache.hadoop.hive.metastore.api.Partition tPartition.setDbName(databaseName) tPartition.setTableName(tableName) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index 1656587d14835..d834b4e83e043 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -72,12 +72,10 @@ private[hive] case class HiveTable( def isPartitioned: Boolean = partitionColumns.nonEmpty - def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = { - predicates match { - case Nil => client.getAllPartitions(this) - case _ => client.getPartitionsByFilter(this, predicates) - } - } + def getAllPartitions: 
Seq[HivePartition] = client.getAllPartitions(this) + + def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = + client.getPartitionsByFilter(this, predicates) // Hive does not support backticks when passing names to the client. def qualifiedName: String = s"$database.$name" From daa1964b6098f79100def78451bda181b5c92198 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Jul 2015 17:59:43 -0700 Subject: [PATCH 099/219] [SPARK-8882] [STREAMING] Add a new Receiver scheduling mechanism The design doc: https://docs.google.com/document/d/1ZsoRvHjpISPrDmSjsGzuSu8UjwgbtmoCTzmhgTurHJw/edit?usp=sharing Author: zsxwing Closes #7276 from zsxwing/receiver-scheduling and squashes the following commits: 137b257 [zsxwing] Add preferredNumExecutors to rescheduleReceiver 61a6c3f [zsxwing] Set state to ReceiverState.INACTIVE in deregisterReceiver 5e1fa48 [zsxwing] Fix the code style 7451498 [zsxwing] Move DummyReceiver back to ReceiverTrackerSuite 715ef9c [zsxwing] Rename: scheduledLocations -> scheduledExecutors; locations -> executors 05daf9c [zsxwing] Use receiverTrackingInfo.toReceiverInfo 1d6d7c8 [zsxwing] Merge branch 'master' into receiver-scheduling 8f93c8d [zsxwing] Use hostPort as the receiver location rather than host; fix comments and unit tests 59f8887 [zsxwing] Schedule all receivers at the same time when launching them 075e0a3 [zsxwing] Add receiver RDD name; use '!isTrackerStarted' instead 276a4ac [zsxwing] Remove "ReceiverLauncher" and move codes to "launchReceivers" fab9a01 [zsxwing] Move methods back to the outer class 4e639c4 [zsxwing] Fix unintentional changes f60d021 [zsxwing] Reorganize ReceiverTracker to use an event loop for lock free 105037e [zsxwing] Merge branch 'master' into receiver-scheduling 5fee132 [zsxwing] Update tha scheduling algorithm to avoid to keep restarting Receiver 9e242c8 [zsxwing] Remove the ScheduleReceiver message because we can refuse it when receiving RegisterReceiver a9acfbf [zsxwing] Merge branch 'squash-pr-6294' into receiver-scheduling 881edb9 [zsxwing] ReceiverScheduler -> ReceiverSchedulingPolicy e530bcc [zsxwing] [SPARK-5681][Streaming] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time #6294 3b87e4a [zsxwing] Revert SparkContext.scala a86850c [zsxwing] Remove submitAsyncJob and revert JobWaiter f549595 [zsxwing] Add comments for the scheduling approach 9ecc08e [zsxwing] Fix comments and code style 28d1bee [zsxwing] Make 'host' protected; rescheduleReceiver -> getAllowedLocations 2c86a9e [zsxwing] Use tryFailure to support calling jobFailed multiple times ca6fe35 [zsxwing] Add a test for Receiver.restart 27acd45 [zsxwing] Add unit tests for LoadBalanceReceiverSchedulerImplSuite cc76142 [zsxwing] Add JobWaiter.toFuture to avoid blocking threads d9a3e72 [zsxwing] Add a new Receiver scheduling mechanism --- .../receiver/ReceiverSupervisor.scala | 4 +- .../receiver/ReceiverSupervisorImpl.scala | 6 +- .../streaming/scheduler/ReceiverInfo.scala | 1 - .../scheduler/ReceiverSchedulingPolicy.scala | 171 +++++++ .../streaming/scheduler/ReceiverTracker.scala | 468 +++++++++++------- .../scheduler/ReceiverTrackingInfo.scala | 55 ++ .../ReceiverSchedulingPolicySuite.scala | 130 +++++ .../scheduler/ReceiverTrackerSuite.scala | 66 +-- .../StreamingJobProgressListenerSuite.scala | 6 +- 9 files changed, 674 insertions(+), 233 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala create mode 100644 
streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index a7c220f426ecf..e98017a63756e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{SparkEnv, Logging, SparkConf} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{Utils, ThreadUtils} /** * Abstract class that is responsible for supervising a Receiver in the worker. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 2f6841ee8879c..0d802f83549af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.RpcUtils import org.apache.spark.{Logging, SparkEnv, SparkException} /** @@ -46,6 +46,8 @@ private[streaming] class ReceiverSupervisorImpl( checkpointDirOption: Option[String] ) extends ReceiverSupervisor(receiver, env.conf) with Logging { + private val hostPort = SparkEnv.get.blockManager.blockManagerId.hostPort + private val receivedBlockHandler: ReceivedBlockHandler = { if (WriteAheadLogUtils.enableReceiverLog(env.conf)) { if (checkpointDirOption.isEmpty) { @@ -170,7 +172,7 @@ private[streaming] class ReceiverSupervisorImpl( override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( - streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + streamId, receiver.getClass.getSimpleName, hostPort, endpoint) trackerEndpoint.askWithRetry[Boolean](msg) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index de85f24dd988d..59df892397fe0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -28,7 +28,6 @@ import org.apache.spark.rpc.RpcEndpointRef case class ReceiverInfo( streamId: Int, name: String, - private[streaming] val endpoint: RpcEndpointRef, active: Boolean, location: String, lastErrorMessage: String = "", diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala new file mode 100644 index 0000000000000..ef5b687b5831a --- /dev/null +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicy.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.Map +import scala.collection.mutable + +import org.apache.spark.streaming.receiver.Receiver + +private[streaming] class ReceiverSchedulingPolicy { + + /** + * Try our best to schedule receivers with evenly distributed. However, if the + * `preferredLocation`s of receivers are not even, we may not be able to schedule them evenly + * because we have to respect them. + * + * Here is the approach to schedule executors: + *
<ol>
+ *   <li>First, schedule all the receivers with preferred locations (hosts), evenly among the
+ *       executors running on those host.</li>
+ *   <li>Then, schedule all other receivers evenly among all the executors such that overall
+ *       distribution over all the receivers is even.</li>
+ * </ol>
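+ *
+ * For example (hypothetical numbers, mirroring ReceiverSchedulingPolicySuite below): six
+ * receivers on three executors with no preferred locations get two receivers per executor,
+ * and three receivers on six executors get two candidate executors each.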
+ * + * This method is called when we start to launch receivers at the first time. + */ + def scheduleReceivers( + receivers: Seq[Receiver[_]], executors: Seq[String]): Map[Int, Seq[String]] = { + if (receivers.isEmpty) { + return Map.empty + } + + if (executors.isEmpty) { + return receivers.map(_.streamId -> Seq.empty).toMap + } + + val hostToExecutors = executors.groupBy(_.split(":")(0)) + val scheduledExecutors = Array.fill(receivers.length)(new mutable.ArrayBuffer[String]) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // Set the initial value to 0 + executors.foreach(e => numReceiversOnExecutor(e) = 0) + + // Firstly, we need to respect "preferredLocation". So if a receiver has "preferredLocation", + // we need to make sure the "preferredLocation" is in the candidate scheduled executor list. + for (i <- 0 until receivers.length) { + // Note: preferredLocation is host but executors are host:port + receivers(i).preferredLocation.foreach { host => + hostToExecutors.get(host) match { + case Some(executorsOnHost) => + // preferredLocation is a known host. Select an executor that has the least receivers in + // this host + val leastScheduledExecutor = + executorsOnHost.minBy(executor => numReceiversOnExecutor(executor)) + scheduledExecutors(i) += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = + numReceiversOnExecutor(leastScheduledExecutor) + 1 + case None => + // preferredLocation is an unknown host. + // Note: There are two cases: + // 1. This executor is not up. But it may be up later. + // 2. This executor is dead, or it's not a host in the cluster. + // Currently, simply add host to the scheduled executors. + scheduledExecutors(i) += host + } + } + } + + // For those receivers that don't have preferredLocation, make sure we assign at least one + // executor to them. + for (scheduledExecutorsForOneReceiver <- scheduledExecutors.filter(_.isEmpty)) { + // Select the executor that has the least receivers + val (leastScheduledExecutor, numReceivers) = numReceiversOnExecutor.minBy(_._2) + scheduledExecutorsForOneReceiver += leastScheduledExecutor + numReceiversOnExecutor(leastScheduledExecutor) = numReceivers + 1 + } + + // Assign idle executors to receivers that have less executors + val idleExecutors = numReceiversOnExecutor.filter(_._2 == 0).map(_._1) + for (executor <- idleExecutors) { + // Assign an idle executor to the receiver that has least candidate executors. + val leastScheduledExecutors = scheduledExecutors.minBy(_.size) + leastScheduledExecutors += executor + } + + receivers.map(_.streamId).zip(scheduledExecutors).toMap + } + + /** + * Return a list of candidate executors to run the receiver. If the list is empty, the caller can + * run this receiver in arbitrary executor. The caller can use `preferredNumExecutors` to require + * returning `preferredNumExecutors` executors if possible. + * + * This method tries to balance executors' load. Here is the approach to schedule executors + * for a receiver. + *
<ol>
+ *   <li>If preferredLocation is set, preferredLocation should be one of the candidate
+ *       executors.</li>
+ *   <li>Every executor will be assigned to a weight according to the receivers running or
+ *       scheduling on it.
+ *       <ul>
+ *         <li>If a receiver is running on an executor, it contributes 1.0 to the executor's
+ *             weight.</li>
+ *         <li>If a receiver is scheduled to an executor but has not yet run, it contributes
+ *             `1.0 / #candidate_executors_of_this_receiver` to the executor's weight.</li>
+ *       </ul>
+ *       At last, if there are more than `preferredNumExecutors` idle executors (weight = 0),
+ *       returns all idle executors. Otherwise, we only return `preferredNumExecutors` best
+ *       options according to the weights.</li>
+ * </ol>
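+ *
+ * Worked example (numbers taken from ReceiverSchedulingPolicySuite below): with executors
+ * host1..host5, a receiver running on host1 contributes 1.0, and receivers scheduled to
+ * {host2, host3} and {host1, host3} contribute 0.5 each, giving weights host1 = 1.5,
+ * host2 = 0.5, host3 = 1.0. host4 and host5 are idle, so the three best options are
+ * {host2, host4, host5}.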
+ * + * This method is called when a receiver is registering with ReceiverTracker or is restarting. + */ + def rescheduleReceiver( + receiverId: Int, + preferredLocation: Option[String], + receiverTrackingInfoMap: Map[Int, ReceiverTrackingInfo], + executors: Seq[String], + preferredNumExecutors: Int = 3): Seq[String] = { + if (executors.isEmpty) { + return Seq.empty + } + + // Always try to schedule to the preferred locations + val scheduledExecutors = mutable.Set[String]() + scheduledExecutors ++= preferredLocation + + val executorWeights = receiverTrackingInfoMap.values.flatMap { receiverTrackingInfo => + receiverTrackingInfo.state match { + case ReceiverState.INACTIVE => Nil + case ReceiverState.SCHEDULED => + val scheduledExecutors = receiverTrackingInfo.scheduledExecutors.get + // The probability that a scheduled receiver will run in an executor is + // 1.0 / scheduledLocations.size + scheduledExecutors.map(location => location -> (1.0 / scheduledExecutors.size)) + case ReceiverState.ACTIVE => Seq(receiverTrackingInfo.runningExecutor.get -> 1.0) + } + }.groupBy(_._1).mapValues(_.map(_._2).sum) // Sum weights for each executor + + val idleExecutors = (executors.toSet -- executorWeights.keys).toSeq + if (idleExecutors.size >= preferredNumExecutors) { + // If there are more than `preferredNumExecutors` idle executors, return all of them + scheduledExecutors ++= idleExecutors + } else { + // If there are less than `preferredNumExecutors` idle executors, return 3 best options + scheduledExecutors ++= idleExecutors + val sortedExecutors = executorWeights.toSeq.sortBy(_._2).map(_._1) + scheduledExecutors ++= (idleExecutors ++ sortedExecutors).take(preferredNumExecutors) + } + scheduledExecutors.toSeq + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 9cc6ffcd12f61..6270137951b5a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,17 +17,27 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} +import java.util.concurrent.{TimeUnit, CountDownLatch} + +import scala.collection.mutable.HashMap +import scala.concurrent.ExecutionContext import scala.language.existentials -import scala.math.max +import scala.util.{Failure, Success} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} +import org.apache.spark._ +import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ import org.apache.spark.streaming.{StreamingContext, Time} -import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, - StopReceiver, UpdateRateLimit} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.{ThreadUtils, SerializableConfiguration} + + +/** Enumeration to identify current state of a Receiver */ +private[streaming] object ReceiverState extends Enumeration { + type ReceiverState = Value + val INACTIVE, SCHEDULED, ACTIVE = Value +} /** * Messages used by the NetworkReceiver and the ReceiverTracker to communicate @@ -37,7 +47,7 @@ private[streaming] sealed trait ReceiverTrackerMessage private[streaming] case class RegisterReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, 
receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) @@ -46,7 +56,38 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage -private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage +/** + * Messages used by the driver and ReceiverTrackerEndpoint to communicate locally. + */ +private[streaming] sealed trait ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver. + */ +private[streaming] case class RestartReceiver(receiver: Receiver[_]) + extends ReceiverTrackerLocalMessage + +/** + * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers + * at the first time. + */ +private[streaming] case class StartAllReceivers(receiver: Seq[Receiver[_]]) + extends ReceiverTrackerLocalMessage + +/** + * This message will trigger ReceiverTrackerEndpoint to send stop signals to all registered + * receivers. + */ +private[streaming] case object StopAllReceivers extends ReceiverTrackerLocalMessage + +/** + * A message used by ReceiverTracker to ask all receiver's ids still stored in + * ReceiverTrackerEndpoint. + */ +private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessage + +private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) + extends ReceiverTrackerLocalMessage /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of @@ -60,8 +101,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private val receiverInputStreams = ssc.graph.getReceiverInputStreams() private val receiverInputStreamIds = receiverInputStreams.map { _.id } - private val receiverExecutor = new ReceiverLauncher() - private val receiverInfo = new HashMap[Int, ReceiverInfo] with SynchronizedMap[Int, ReceiverInfo] private val receivedBlockTracker = new ReceivedBlockTracker( ssc.sparkContext.conf, ssc.sparkContext.hadoopConfiguration, @@ -86,6 +125,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null + private val schedulingPolicy = new ReceiverSchedulingPolicy() + + // Track the active receiver job number. When a receiver job exits ultimately, countDown will + // be called. + private val receiverJobExitLatch = new CountDownLatch(receiverInputStreams.size) + + /** + * Track all receivers' information. The key is the receiver id, the value is the receiver info. + * It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverTrackingInfos = new HashMap[Int, ReceiverTrackingInfo] + + /** + * Store all preferred locations for all receivers. We need this information to schedule + * receivers. It's only accessed in ReceiverTrackerEndpoint. + */ + private val receiverPreferredLocations = new HashMap[Int, Option[String]] + /** Start the endpoint and receiver execution thread. 
*/ def start(): Unit = synchronized { if (isTrackerStarted) { @@ -95,7 +152,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (!receiverInputStreams.isEmpty) { endpoint = ssc.env.rpcEnv.setupEndpoint( "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) - if (!skipReceiverLaunch) receiverExecutor.start() + if (!skipReceiverLaunch) launchReceivers() logInfo("ReceiverTracker started") trackerState = Started } @@ -112,20 +169,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Wait for the Spark job that runs the receivers to be over // That is, for the receivers to quit gracefully. - receiverExecutor.awaitTermination(10000) + receiverJobExitLatch.await(10, TimeUnit.SECONDS) if (graceful) { - val pollTime = 100 logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || receiverExecutor.running) { - Thread.sleep(pollTime) - } + receiverJobExitLatch.await() logInfo("Waited for receiver job to terminate gracefully") } // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) + val receivers = endpoint.askWithRetry[Seq[Int]](AllReceiverIds) + if (receivers.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receivers) } else { logInfo("All of the receivers have deregistered successfully") } @@ -154,9 +209,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Get the blocks allocated to the given batch and stream. */ def getBlocksOfBatchAndStream(batchTime: Time, streamId: Int): Seq[ReceivedBlockInfo] = { - synchronized { - receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) - } + receivedBlockTracker.getBlocksOfBatchAndStream(batchTime, streamId) } /** @@ -170,8 +223,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - receiverInfo.values.flatMap { info => Option(info.endpoint) } - .foreach { _.send(CleanupOldBlocks(cleanupThreshTime)) } + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) } } @@ -179,7 +231,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false private def registerReceiver( streamId: Int, typ: String, - host: String, + hostPort: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress ): Boolean = { @@ -189,13 +241,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false if (isTrackerStopping || isTrackerStopped) { false + } else if (!scheduleReceiver(streamId).contains(hostPort)) { + // Refuse it since it's scheduled to a wrong executor + false } else { - // "stopReceivers" won't happen at the same time because both "registerReceiver" and are - // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If - // "stopReceivers" is called later, it should be able to see this receiver. 
- receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + val name = s"${typ}-${streamId}" + val receiverTrackingInfo = ReceiverTrackingInfo( + streamId, + ReceiverState.ACTIVE, + scheduledExecutors = None, + runningExecutor = Some(hostPort), + name = Some(name), + endpoint = Some(receiverEndpoint)) + receiverTrackingInfos.put(streamId, receiverTrackingInfo) + listenerBus.post(StreamingListenerReceiverStarted(receiverTrackingInfo.toReceiverInfo)) logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) true } @@ -203,21 +262,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val lastErrorTime = + if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() + val errorInfo = ReceiverErrorInfo( + lastErrorMessage = message, lastError = error, lastErrorTime = lastErrorTime) + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - oldInfo.copy(endpoint = null, active = false, lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + oldInfo.copy(state = ReceiverState.INACTIVE, errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - val lastErrorTime = - if (error == null || error == "") -1 else ssc.scheduler.clock.getTimeMillis() - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = lastErrorTime) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo -= streamId - listenerBus.post(StreamingListenerReceiverStopped(newReceiverInfo)) + receiverTrackingInfos -= streamId + listenerBus.post(StreamingListenerReceiverStopped(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -228,9 +286,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Update a receiver's maximum ingestion rate */ def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - for (info <- receiverInfo.get(streamUID); eP <- Option(info.endpoint)) { - eP.send(UpdateRateLimit(newRate)) - } + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) } /** Add new blocks for the given stream */ @@ -240,16 +296,21 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** Report error sent by a receiver */ private def reportError(streamId: Int, message: String, error: String) { - val newReceiverInfo = receiverInfo.get(streamId) match { + val newReceiverTrackingInfo = receiverTrackingInfos.get(streamId) match { case Some(oldInfo) => - oldInfo.copy(lastErrorMessage = message, lastError = error) + val errorInfo = ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = oldInfo.errorInfo.map(_.lastErrorTime).getOrElse(-1L)) + oldInfo.copy(errorInfo = Some(errorInfo)) case None => logWarning("No prior receiver info") - ReceiverInfo(streamId, "", null, false, "", lastErrorMessage = message, - lastError = error, lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + val errorInfo = 
ReceiverErrorInfo(lastErrorMessage = message, lastError = error, + lastErrorTime = ssc.scheduler.clock.getTimeMillis()) + ReceiverTrackingInfo( + streamId, ReceiverState.INACTIVE, None, None, None, None, Some(errorInfo)) } - receiverInfo(streamId) = newReceiverInfo - listenerBus.post(StreamingListenerReceiverError(receiverInfo(streamId))) + + receiverTrackingInfos(streamId) = newReceiverTrackingInfo + listenerBus.post(StreamingListenerReceiverError(newReceiverTrackingInfo.toReceiverInfo)) val messageWithError = if (error != null && !error.isEmpty) { s"$message - $error" } else { @@ -258,171 +319,242 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logWarning(s"Error reported by receiver for stream $streamId: $messageWithError") } + private def scheduleReceiver(receiverId: Int): Seq[String] = { + val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None) + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiverId, preferredLocation, receiverTrackingInfos, getExecutors) + updateReceiverScheduledExecutors(receiverId, scheduledExecutors) + scheduledExecutors + } + + private def updateReceiverScheduledExecutors( + receiverId: Int, scheduledExecutors: Seq[String]): Unit = { + val newReceiverTrackingInfo = receiverTrackingInfos.get(receiverId) match { + case Some(oldInfo) => + oldInfo.copy(state = ReceiverState.SCHEDULED, + scheduledExecutors = Some(scheduledExecutors)) + case None => + ReceiverTrackingInfo( + receiverId, + ReceiverState.SCHEDULED, + Some(scheduledExecutors), + runningExecutor = None) + } + receiverTrackingInfos.put(receiverId, newReceiverTrackingInfo) + } + /** Check if any blocks are left to be processed */ def hasUnallocatedBlocks: Boolean = { receivedBlockTracker.hasUnallocatedReceivedBlocks } + /** + * Get the list of executors excluding driver + */ + private def getExecutors: Seq[String] = { + if (ssc.sc.isLocal) { + Seq(ssc.sparkContext.env.blockManager.blockManagerId.hostPort) + } else { + ssc.sparkContext.env.blockManager.master.getMemoryStatus.filter { case (blockManagerId, _) => + blockManagerId.executorId != SparkContext.DRIVER_IDENTIFIER // Ignore the driver location + }.map { case (blockManagerId, _) => blockManagerId.hostPort }.toSeq + } + } + + /** + * Run the dummy Spark job to ensure that all slaves have registered. This avoids all the + * receivers to be scheduled on the same node. + * + * TODO Should poll the executor number and wait for executors according to + * "spark.scheduler.minRegisteredResourcesRatio" and + * "spark.scheduler.maxRegisteredResourcesWaitingTime" rather than running a dummy job. + */ + private def runDummySparkJob(): Unit = { + if (!ssc.sparkContext.isLocal) { + ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() + } + assert(getExecutors.nonEmpty) + } + + /** + * Get the receivers from the ReceiverInputDStreams, distributes them to the + * worker nodes as a parallel collection, and runs them. 
+ */ + private def launchReceivers(): Unit = { + val receivers = receiverInputStreams.map(nis => { + val rcvr = nis.getReceiver() + rcvr.setReceiverId(nis.id) + rcvr + }) + + runDummySparkJob() + + logInfo("Starting " + receivers.length + " receivers") + endpoint.send(StartAllReceivers(receivers)) + } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted: Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping: Boolean = trackerState == Stopping + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped: Boolean = trackerState == Stopped + /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { + // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged + private val submitJobThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + override def receive: PartialFunction[Any, Unit] = { + // Local messages + case StartAllReceivers(receivers) => + val scheduledExecutors = schedulingPolicy.scheduleReceivers(receivers, getExecutors) + for (receiver <- receivers) { + val executors = scheduledExecutors(receiver.streamId) + updateReceiverScheduledExecutors(receiver.streamId, executors) + receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation + startReceiver(receiver, executors) + } + case RestartReceiver(receiver) => + val scheduledExecutors = schedulingPolicy.rescheduleReceiver( + receiver.streamId, + receiver.preferredLocation, + receiverTrackingInfos, + getExecutors) + updateReceiverScheduledExecutors(receiver.streamId, scheduledExecutors) + startReceiver(receiver, scheduledExecutors) + case c: CleanupOldBlocks => + receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c)) + case UpdateReceiverRateLimit(streamUID, newRate) => + for (info <- receiverTrackingInfos.get(streamUID); eP <- info.endpoint) { + eP.send(UpdateRateLimit(newRate)) + } + // Remote messages case ReportError(streamId, message, error) => reportError(streamId, message, error) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterReceiver(streamId, typ, host, receiverEndpoint) => + // Remote messages + case RegisterReceiver(streamId, typ, hostPort, receiverEndpoint) => val successful = - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, hostPort, receiverEndpoint, context.sender.address) context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + // Local messages + case AllReceiverIds => + context.reply(receiverTrackingInfos.keys.toSeq) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() context.reply(true) } - /** Send stop signal to the receivers. */ - private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") - } - } - - /** This thread class runs all the receivers on the cluster. 
*/ - class ReceiverLauncher { - @transient val env = ssc.env - @volatile @transient var running = false - @transient val thread = new Thread() { - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverLauncher interrupted") - } - } - } - - def start() { - thread.start() - } - /** - * Get the list of executors excluding driver - */ - private def getExecutors(ssc: StreamingContext): List[String] = { - val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList - val driver = ssc.sparkContext.getConf.get("spark.driver.host") - executors.diff(List(driver)) - } - - /** Set host location(s) for each receiver so as to distribute them over - * executors in a round-robin fashion taking into account preferredLocation if set + * Start a receiver along with its scheduled executors */ - private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], - executors: List[String]): Array[ArrayBuffer[String]] = { - val locations = new Array[ArrayBuffer[String]](receivers.length) - var i = 0 - for (i <- 0 until receivers.length) { - locations(i) = new ArrayBuffer[String]() - if (receivers(i).preferredLocation.isDefined) { - locations(i) += receivers(i).preferredLocation.get - } + private def startReceiver(receiver: Receiver[_], scheduledExecutors: Seq[String]): Unit = { + val receiverId = receiver.streamId + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + return } - var count = 0 - for (i <- 0 until max(receivers.length, executors.length)) { - if (!receivers(i % receivers.length).preferredLocation.isDefined) { - locations(i % receivers.length) += executors(count) - count += 1 - if (count == executors.length) { - count = 0 - } - } - } - locations - } - - /** - * Get the receivers from the ReceiverInputDStreams, distributes them to the - * worker nodes as a parallel collection, and runs them. - */ - private def startReceivers() { - val receivers = receiverInputStreams.map(nis => { - val rcvr = nis.getReceiver() - rcvr.setReceiverId(nis.id) - rcvr - }) val checkpointDirOption = Option(ssc.checkpointDir) val serializableHadoopConf = new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) // Function to start the receiver on the worker node - val startReceiver = (iterator: Iterator[Receiver[_]]) => { - if (!iterator.hasNext) { - throw new SparkException( - "Could not start receiver as object not found.") - } - val receiver = iterator.next() - val supervisor = new ReceiverSupervisorImpl( - receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) - supervisor.start() - supervisor.awaitTermination() - } - - // Run the dummy Spark job to ensure that all slaves have registered. - // This avoids all the receivers to be scheduled on the same node. 
- if (!ssc.sparkContext.isLocal) { - ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() - } + val startReceiverFunc = new StartReceiverFunc(checkpointDirOption, serializableHadoopConf) - // Get the list of executors and schedule receivers - val executors = getExecutors(ssc) - val tempRDD = - if (!executors.isEmpty) { - val locations = scheduleReceivers(receivers, executors) - val roundRobinReceivers = (0 until receivers.length).map(i => - (receivers(i), locations(i))) - ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) + // Create the RDD using the scheduledExecutors to run the receiver in a Spark job + val receiverRDD: RDD[Receiver[_]] = + if (scheduledExecutors.isEmpty) { + ssc.sc.makeRDD(Seq(receiver), 1) } else { - ssc.sc.makeRDD(receivers, receivers.size) + ssc.sc.makeRDD(Seq(receiver -> scheduledExecutors)) } + receiverRDD.setName(s"Receiver $receiverId") + val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit]( + receiverRDD, startReceiverFunc, Seq(0), (_, _) => Unit, ()) + // We will keep restarting the receiver job until ReceiverTracker is stopped + future.onComplete { + case Success(_) => + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + } else { + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + case Failure(e) => + if (!isTrackerStarted) { + onReceiverJobFinish(receiverId) + } else { + logError("Receiver has been stopped. Try to restart it.", e) + logInfo(s"Restarting Receiver $receiverId") + self.send(RestartReceiver(receiver)) + } + }(submitJobThreadPool) + logInfo(s"Receiver ${receiver.streamId} started") + } - // Distribute the receivers and start them - logInfo("Starting " + receivers.length + " receivers") - running = true - try { - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - logInfo("All of the receivers have been terminated") - } finally { - running = false - } + override def onStop(): Unit = { + submitJobThreadPool.shutdownNow() } /** - * Wait until the Spark job that runs the receivers is terminated, or return when - * `milliseconds` elapses + * Call when a receiver is terminated. It means we won't restart its Spark job. */ - def awaitTermination(milliseconds: Long): Unit = { - thread.join(milliseconds) + private def onReceiverJobFinish(receiverId: Int): Unit = { + receiverJobExitLatch.countDown() + receiverTrackingInfos.remove(receiverId).foreach { receiverTrackingInfo => + if (receiverTrackingInfo.state == ReceiverState.ACTIVE) { + logWarning(s"Receiver $receiverId exited but didn't deregister") + } + } } - } - /** Check if tracker has been marked for starting */ - private def isTrackerStarted(): Boolean = trackerState == Started + /** Send stop signal to the receivers. */ + private def stopReceivers() { + receiverTrackingInfos.values.flatMap(_.endpoint).foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverTrackingInfos.size + " receivers") + } + } - /** Check if tracker has been marked for stopping */ - private def isTrackerStopping(): Boolean = trackerState == Stopping +} - /** Check if tracker has been marked for stopped */ - private def isTrackerStopped(): Boolean = trackerState == Stopped +/** + * Function to start the receiver on the worker node. Use a class instead of closure to avoid + * the serialization issue. 
+ */ +private class StartReceiverFunc( + checkpointDirOption: Option[String], + serializableHadoopConf: SerializableConfiguration) + extends (Iterator[Receiver[_]] => Unit) with Serializable { + + override def apply(iterator: Iterator[Receiver[_]]): Unit = { + if (!iterator.hasNext) { + throw new SparkException( + "Could not start receiver as object not found.") + } + if (TaskContext.get().attemptNumber() == 0) { + val receiver = iterator.next() + assert(iterator.hasNext == false) + val supervisor = new ReceiverSupervisorImpl( + receiver, SparkEnv.get, serializableHadoopConf.value, checkpointDirOption) + supervisor.start() + supervisor.awaitTermination() + } else { + // It's restarted by TaskScheduler, but we want to reschedule it again. So exit it. + } + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala new file mode 100644 index 0000000000000..043ff4d0ff054 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.streaming.scheduler.ReceiverState._ + +private[streaming] case class ReceiverErrorInfo( + lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L) + +/** + * Class having information about a receiver. + * + * @param receiverId the unique receiver id + * @param state the current Receiver state + * @param scheduledExecutors the scheduled executors provided by ReceiverSchedulingPolicy + * @param runningExecutor the running executor if the receiver is active + * @param name the receiver name + * @param endpoint the receiver endpoint. 
It can be used to send messages to the receiver + * @param errorInfo the receiver error information if it fails + */ +private[streaming] case class ReceiverTrackingInfo( + receiverId: Int, + state: ReceiverState, + scheduledExecutors: Option[Seq[String]], + runningExecutor: Option[String], + name: Option[String] = None, + endpoint: Option[RpcEndpointRef] = None, + errorInfo: Option[ReceiverErrorInfo] = None) { + + def toReceiverInfo: ReceiverInfo = ReceiverInfo( + receiverId, + name.getOrElse(""), + state == ReceiverState.ACTIVE, + location = runningExecutor.getOrElse(""), + lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), + lastError = errorInfo.map(_.lastError).getOrElse(""), + lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) + ) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala new file mode 100644 index 0000000000000..93f920fdc71f1 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite + +class ReceiverSchedulingPolicySuite extends SparkFunSuite { + + val receiverSchedulingPolicy = new ReceiverSchedulingPolicy + + test("rescheduleReceiver: empty executors") { + val scheduledExecutors = + receiverSchedulingPolicy.rescheduleReceiver(0, None, Map.empty, executors = Seq.empty) + assert(scheduledExecutors === Seq.empty) + } + + test("rescheduleReceiver: receiver preferredLocation") { + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.INACTIVE, None, None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 0, Some("host1"), receiverTrackingInfoMap, executors = Seq("host2")) + assert(scheduledExecutors.toSet === Set("host1", "host2")) + } + + test("rescheduleReceiver: return all idle executors if more than 3 idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // host3 is idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1"))) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 1, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host3", "host4", "host5")) + } + + test("rescheduleReceiver: return 3 best options if less than 3 idle executors") { + val executors = Seq("host1", "host2", "host3", "host4", "host5") + // Weights: host1 = 1.5, host2 = 0.5, host3 = 1.0 + // host4 and host5 are idle + val receiverTrackingInfoMap = Map( + 0 -> ReceiverTrackingInfo(0, ReceiverState.ACTIVE, None, Some("host1")), + 1 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host2", "host3")), None), + 2 -> ReceiverTrackingInfo(1, ReceiverState.SCHEDULED, Some(Seq("host1", "host3")), None)) + val scheduledExecutors = receiverSchedulingPolicy.rescheduleReceiver( + 3, None, receiverTrackingInfoMap, executors) + assert(scheduledExecutors.toSet === Set("host2", "host4", "host5")) + } + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more receivers than executors") { + val receivers = (0 until 6).map(new DummyReceiver(_)) + val executors = (10000 until 10003).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 2 receivers running on each executor and each receiver has one executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + numReceiversOnExecutor(executors(0)) = numReceiversOnExecutor.getOrElse(executors(0), 0) + 1 + } + assert(numReceiversOnExecutor === executors.map(_ -> 2).toMap) + } + + + test("scheduleReceivers: " + + "schedule receivers evenly when there are more executors than receivers") { + val receivers = (0 until 3).map(new DummyReceiver(_)) + val executors = (10000 until 10006).map(port => s"localhost:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has two executors + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 2) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 
1).toMap) + } + + test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") { + val receivers = (0 until 3).map(new DummyReceiver(_)) ++ + (3 until 6).map(new DummyReceiver(_, Some("localhost"))) + val executors = (10000 until 10003).map(port => s"localhost:${port}") ++ + (10003 until 10006).map(port => s"localhost2:${port}") + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) + val numReceiversOnExecutor = mutable.HashMap[String, Int]() + // There should be 1 receiver running on each executor and each receiver has 1 executor + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.size == 1) + executors.foreach { l => + numReceiversOnExecutor(l) = numReceiversOnExecutor.getOrElse(l, 0) + 1 + } + } + assert(numReceiversOnExecutor === executors.map(_ -> 1).toMap) + // Make sure we schedule the receivers to their preferredLocations + val executorsForReceiversWithPreferredLocation = + scheduledExecutors.filter { case (receiverId, executors) => receiverId >= 3 }.flatMap(_._2) + // We can simply check the executor set because we only know each receiver only has 1 executor + assert(executorsForReceiversWithPreferredLocation.toSet === + (10000 until 10003).map(port => s"localhost:${port}").toSet) + } + + test("scheduleReceivers: return empty if no receiver") { + assert(receiverSchedulingPolicy.scheduleReceivers(Seq.empty, Seq("localhost:10000")).isEmpty) + } + + test("scheduleReceivers: return empty scheduled executors if no executors") { + val receivers = (0 until 3).map(new DummyReceiver(_)) + val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty) + scheduledExecutors.foreach { case (receiverId, executors) => + assert(executors.isEmpty) + } + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index aadb7231757b8..e2159bd4f225d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -18,66 +18,18 @@ package org.apache.spark.streaming.scheduler import org.scalatest.concurrent.Eventually._ -import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import org.apache.spark.streaming._ + import org.apache.spark.SparkConf -import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming._ import org.apache.spark.streaming.receiver._ -import org.apache.spark.util.Utils -import org.apache.spark.streaming.dstream.InputDStream -import scala.reflect.ClassTag import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") val ssc = new StreamingContext(sparkConf, Milliseconds(100)) - val tracker = new ReceiverTracker(ssc) - val launcher = new tracker.ReceiverLauncher() - val executors: List[String] = List("0", "1", "2", "3") - - test("receiver scheduling - all or none have preferred location") { - - def parse(s: String): Array[Array[String]] = { - val outerSplit = s.split("\\|") - val loc = new Array[Array[String]](outerSplit.length) - var i = 0 - for (i <- 0 until outerSplit.length) { - loc(i) = outerSplit(i).split("\\,") - } - loc - } - - def 
testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { - val receivers = - if (preferredLocation) { - Array.tabulate(numReceivers)(i => new DummyReceiver(host = - Some(((i + 1) % executors.length).toString))) - } else { - Array.tabulate(numReceivers)(_ => new DummyReceiver) - } - val locations = launcher.scheduleReceivers(receivers, executors) - val expectedLocations = parse(allocation) - assert(locations.deep === expectedLocations.deep) - } - - testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") - testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") - testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") - } - - test("receiver scheduling - some have preferred location") { - val numReceivers = 4; - val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), - new DummyReceiver, new DummyReceiver, new DummyReceiver) - val locations = launcher.scheduleReceivers(receivers, executors) - assert(locations(0)(0) === "1") - assert(locations(1)(0) === "0") - assert(locations(2)(0) === "1") - assert(locations(0).length === 1) - assert(locations(3).length === 1) - } test("Receiver tracker - propagates rate limit") { object ReceiverStartedWaiter extends StreamingListener { @@ -134,19 +86,19 @@ private class RateLimitInputDStream(@transient ssc_ : StreamingContext) * @note It's necessary to be a top-level object, or else serialization would create another * one on the executor side and we won't be able to read its rate limit. */ -private object SingletonDummyReceiver extends DummyReceiver +private object SingletonDummyReceiver extends DummyReceiver(0) /** * Dummy receiver implementation */ -private class DummyReceiver(host: Option[String] = None) +private class DummyReceiver(receiverId: Int, host: Option[String] = None) extends Receiver[Int](StorageLevel.MEMORY_ONLY) { - def onStart() { - } + setReceiverId(receiverId) - def onStop() { - } + override def onStart(): Unit = {} + + override def onStop(): Unit = {} override def preferredLocation: Option[String] = host } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 40dc1fb601bd0..0891309f956d2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -119,20 +119,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", null, true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", null, true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", 
null, true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) From 2e7f99a004f08a42e86f6f603e4ba35cb52561c4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Mon, 27 Jul 2015 21:08:56 -0700 Subject: [PATCH 100/219] [SPARK-8195] [SPARK-8196] [SQL] udf next_day last_day next_day, returns next certain dayofweek. last_day, returns the last day of the month which given date belongs to. Author: Daoyuan Wang Closes #6986 from adrian-wang/udfnlday and squashes the following commits: ef7e3da [Daoyuan Wang] fix 02b3426 [Daoyuan Wang] address 2 comments dc69630 [Daoyuan Wang] address comments from rxin 8846086 [Daoyuan Wang] address comments from rxin d09bcce [Daoyuan Wang] multi fix 1a9de3d [Daoyuan Wang] function next_day and last_day --- .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/datetimeFunctions.scala | 72 +++++++++++++++++++ .../sql/catalyst/util/DateTimeUtils.scala | 46 ++++++++++++ .../expressions/DateExpressionsSuite.scala | 28 ++++++++ .../org/apache/spark/sql/functions.scala | 17 +++++ .../apache/spark/sql/DateFunctionsSuite.scala | 22 ++++++ 6 files changed, 188 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aa05f448d12bc..61ee6f6f71631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -219,8 +219,10 @@ object FunctionRegistry { expression[DayOfYear]("dayofyear"), expression[DayOfMonth]("dayofmonth"), expression[Hour]("hour"), - expression[Month]("month"), + expression[LastDay]("last_day"), expression[Minute]("minute"), + expression[Month]("month"), + expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), expression[WeekOfYear]("weekofyear"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 9e55f0546e123..b00a1b26fa285 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -265,3 +265,75 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx }) } } + +/** + * Returns the last day of the month which the date belongs to. 
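+ * For example, last_day('2015-07-23') returns 2015-07-31.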
+ */
+case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+  override def child: Expression = startDate
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
+
+  override def dataType: DataType = DateType
+
+  override def prettyName: String = "last_day"
+
+  override def nullSafeEval(date: Any): Any = {
+    val days = date.asInstanceOf[Int]
+    DateTimeUtils.getLastDayOfMonth(days)
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(ctx, ev, (sd) => {
+      s"$dtu.getLastDayOfMonth($sd)"
+    })
+  }
+}
+
+/**
+ * Returns the first date that is later than startDate and falls on the given dayOfWeek.
+ * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first
+ * Sunday later than 2015-07-27.
+ */
+case class NextDay(startDate: Expression, dayOfWeek: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = startDate
+  override def right: Expression = dayOfWeek
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, dayOfW: Any): Any = {
+    val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String])
+    if (dow == -1) {
+      null
+    } else {
+      val sd = start.asInstanceOf[Int]
+      DateTimeUtils.getNextDateForDayOfWeek(sd, dow)
+    }
+  }
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (sd, dowS) => {
+      val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+      val dow = ctx.freshName("dow")
+      val genDow = if (right.foldable) {
+        val dowVal = DateTimeUtils.getDayOfWeekFromString(
+          dayOfWeek.eval(InternalRow.empty).asInstanceOf[UTF8String])
+        s"int $dow = $dowVal;"
+      } else {
+        s"int $dow = $dtu.getDayOfWeekFromString($dowS);"
+      }
+      genDow + s"""
+        if ($dow == -1) {
+          ${ev.isNull} = true;
+        } else {
+          ${ev.primitive} = $dtu.getNextDateForDayOfWeek($sd, $dow);
+        }
+      """
+    })
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 07412e73b6a5b..2e28fb9af9b65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -573,4 +573,50 @@ object DateTimeUtils {
       dayInYear - 334
     }
   }
+
+  /**
+   * Returns the day of week parsed from a string, numbered starting from Thursday,
+   * which is marked as 0 (because 1970-01-01 was a Thursday).
+   */
+  def getDayOfWeekFromString(string: UTF8String): Int = {
+    val dowString = string.toString.toUpperCase
+    dowString match {
+      case "SU" | "SUN" | "SUNDAY" => 3
+      case "MO" | "MON" | "MONDAY" => 4
+      case "TU" | "TUE" | "TUESDAY" => 5
+      case "WE" | "WED" | "WEDNESDAY" => 6
+      case "TH" | "THU" | "THURSDAY" => 0
+      case "FR" | "FRI" | "FRIDAY" => 1
+      case "SA" | "SAT" | "SATURDAY" => 2
+      case _ => -1
+    }
+  }
+
+  /**
+   * Returns the first date that is later than startDate and is of the given dayOfWeek.
+   * dayOfWeek is an integer in the range [0, 6], where 0 is Thursday, 1 is Friday, etc.
+   */
+  def getNextDateForDayOfWeek(startDate: Int, dayOfWeek: Int): Int = {
+    startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7
+  }
+
+  /**
+   * Number of days in each month of a non-leap year.
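+   * (February is listed as 28; getLastDayOfMonth below adds one day for February
+   * of leap years.)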
+ */ + private[this] val daysInNormalYear = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) + + /** + * Returns last day of the month for the given date. The date is expressed in days + * since 1.1.1970. + */ + def getLastDayOfMonth(date: Int): Int = { + val dayOfMonth = getDayOfMonth(date) + val month = getMonth(date) + if (month == 2 && isLeapYear(getYear(date))) { + date + daysInNormalYear(month - 1) + 1 - dayOfMonth + } else { + date + daysInNormalYear(month - 1) - dayOfMonth + } + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index bdba6ce891386..4d2d33765a269 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{StringType, TimestampType, DateType} class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -246,4 +247,31 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("last_day") { + checkEvaluation(LastDay(Literal(Date.valueOf("2015-02-28"))), Date.valueOf("2015-02-28")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-03-27"))), Date.valueOf("2015-03-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-04-26"))), Date.valueOf("2015-04-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-05-25"))), Date.valueOf("2015-05-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-06-24"))), Date.valueOf("2015-06-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-07-23"))), Date.valueOf("2015-07-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-08-01"))), Date.valueOf("2015-08-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-09-02"))), Date.valueOf("2015-09-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-10-03"))), Date.valueOf("2015-10-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-11-04"))), Date.valueOf("2015-11-30")) + checkEvaluation(LastDay(Literal(Date.valueOf("2015-12-05"))), Date.valueOf("2015-12-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) + checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) + } + + test("next_day") { + checkEvaluation( + NextDay(Literal(Date.valueOf("2015-07-23")), Literal("Thu")), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-30"))) + checkEvaluation( + NextDay(Literal(Date.valueOf("2015-07-23")), Literal("THURSDAY")), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-30"))) + checkEvaluation( + NextDay(Literal(Date.valueOf("2015-07-23")), Literal("th")), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-30"))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cab3db609dd4b..d18558b510f0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2032,6 +2032,13 @@ object functions { */ def hour(columnName: String): Column = hour(Column(columnName)) + /** + * Returns the last day of the month which the given date 
belongs to.
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def last_day(e: Column): Column = LastDay(e.expr)
+
   /**
    * Extracts the minutes as an integer from a given date/timestamp/string.
    * @group datetime_funcs
@@ -2046,6 +2053,16 @@ object functions {
    */
   def minute(columnName: String): Column = minute(Column(columnName))
 
+  /**
+   * Returns the first date that is later than the given date `sd` and falls on the day of
+   * week `dow`. For example, `next_day('2015-07-27', "Sunday")` would return 2015-08-02,
+   * which is the first Sunday later than 2015-07-27. The parameter dayOfWeek can be the
+   * 2-letter, 3-letter, or full name of the day of the week (e.g. Mo, tue, FRIDAY).
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def next_day(sd: Column, dayOfWeek: String): Column = NextDay(sd.expr, lit(dayOfWeek).expr)
+
   /**
    * Extracts the seconds as an integer from a given date/timestamp/string.
    * @group datetime_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index 9e80ae86920d9..ff1c7562dc4a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -184,4 +184,26 @@ class DateFunctionsSuite extends QueryTest {
       Row(15, 15, 15))
   }
 
+  test("function last_day") {
+    val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d")
+    val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t")
+    checkAnswer(
+      df1.select(last_day(col("d"))),
+      Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31"))))
+    checkAnswer(
+      df2.select(last_day(col("t"))),
+      Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-07-31"))))
+  }
+
+  test("function next_day") {
+    val df1 = Seq(("mon", "2015-07-23"), ("tuesday", "2015-07-20")).toDF("dow", "d")
+    val df2 = Seq(("th", "2015-07-23 00:11:22"), ("xx", "2015-07-24 11:22:33")).toDF("dow", "t")
+    checkAnswer(
+      df1.select(next_day(col("d"), "MONDAY")),
+      Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27"))))
+    checkAnswer(
+      df2.select(next_day(col("t"), "th")),
+      Seq(Row(Date.valueOf("2015-07-30")), Row(null)))
+  }
+
 }

From 84da8792e2a99736edb6c94df7eda87915a8a476 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 27 Jul 2015 21:41:15 -0700
Subject: [PATCH 101/219] [SPARK-9395][SQL] Create a SpecializedGetters
 interface to track all the specialized getters.

As we are adding more and more specialized getters to more classes (ArrayData coming soon),
this interface can help us prevent missing a method in some implementations.

Author: Reynold Xin

Closes #7713 from rxin/SpecializedGetters and squashes the following commits:

3b39be1 [Reynold Xin] Added override modifier.
567ba9c [Reynold Xin] [SPARK-9395][SQL] Create a SpecializedGetters interface to track all the specialized getters.
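For illustration, a minimal sketch of code written once against the shared interface
(the firstTwoColumns helper is hypothetical; the getter names are the ones declared
in the interface below):

    // Works for any SpecializedGetters implementation, such as InternalRow below.
    def firstTwoColumns(row: SpecializedGetters): (Int, String) = {
      val id = if (row.isNullAt(0)) -1 else row.getInt(0)
      val name = if (row.isNullAt(1)) "" else row.getUTF8String(1).toString
      (id, name)
    }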
--- .../expressions/SpecializedGetters.java | 53 +++++++++++++++++++ .../spark/sql/catalyst/InternalRow.scala | 30 ++++++----- 2 files changed, 69 insertions(+), 14 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java new file mode 100644 index 0000000000000..5f28d52a94bd7 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.UTF8String; + +public interface SpecializedGetters { + + boolean isNullAt(int ordinal); + + boolean getBoolean(int ordinal); + + byte getByte(int ordinal); + + short getShort(int ordinal); + + int getInt(int ordinal); + + long getLong(int ordinal); + + float getFloat(int ordinal); + + double getDouble(int ordinal); + + Decimal getDecimal(int ordinal); + + UTF8String getUTF8String(int ordinal); + + byte[] getBinary(int ordinal); + + Interval getInterval(int ordinal); + + InternalRow getStruct(int ordinal, int numFields); + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 9a11de3840ce2..e395a67434fa7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.{Interval, UTF8String} * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. 
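 * For example, values are read positionally with the typed getters:
 * {{{
 *   val i = row.getInt(0)          // column 0 as a primitive int
 *   val s = row.getUTF8String(1)   // column 1 as a UTF8String
 * }}}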
*/ -abstract class InternalRow extends Serializable { +abstract class InternalRow extends Serializable with SpecializedGetters { def numFields: Int @@ -38,29 +38,30 @@ abstract class InternalRow extends Serializable { def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] - def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) + override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) - def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) + override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) - def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) + override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) - def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) + override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) - def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) + override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) - def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) + override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) - def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) + override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) - def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) + override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) - def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) + override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int): Decimal = + getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) - def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) + override def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString @@ -71,7 +72,8 @@ abstract class InternalRow extends Serializable { * @param ordinal position to get the struct from. * @param numFields number of fields the struct type has */ - def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) + override def getStruct(ordinal: Int, numFields: Int): InternalRow = + getAs[InternalRow](ordinal, null) override def toString: String = s"[${this.mkString(",")}]" From 3bc7055e265ee5c75af8726579663cea0590f6c0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 22:04:54 -0700 Subject: [PATCH 102/219] Fixed a test failure. 
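The failing expectation was in the "function next_day" test: next_day(col("t"), "th")
passes the constant day-of-week "th" for every row, so the second row also resolves to
a date and the "xx" value in the dow column is never consulted. A sketch of the
evaluation (column names from the test below):

    // next_day(col("t"), "th") applies the literal "th" to both rows:
    //   2015-07-23 00:11:22 (a Thursday) -> 2015-07-30
    //   2015-07-24 11:22:33 (a Friday)   -> 2015-07-30
    df2.select(next_day(col("t"), "th"))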
--- .../test/scala/org/apache/spark/sql/DateFunctionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index ff1c7562dc4a6..001fcd035c82a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -203,7 +203,7 @@ class DateFunctionsSuite extends QueryTest { Seq(Row(Date.valueOf("2015-07-27")), Row(Date.valueOf("2015-07-27")))) checkAnswer( df2.select(next_day(col("t"), "th")), - Seq(Row(Date.valueOf("2015-07-30")), Row(null))) + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) } } From 63a492b931765b1edd66624421d503f1927825ec Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Mon, 27 Jul 2015 22:47:31 -0700 Subject: [PATCH 103/219] [SPARK-8828] [SQL] Revert SPARK-5680 JIRA: https://issues.apache.org/jira/browse/SPARK-8828 Author: Yijie Shen Closes #7667 from yjshen/revert_combinesum_2 and squashes the following commits: c37ccb1 [Yijie Shen] add test case 8377214 [Yijie Shen] revert spark.sql.useAggregate2 to its default value e2305ac [Yijie Shen] fix bug - avg on decimal column 7cb0e95 [Yijie Shen] [wip] resolving bugs 1fadb5a [Yijie Shen] remove occurance 17c6248 [Yijie Shen] revert SPARK-5680 --- .../sql/catalyst/expressions/aggregates.scala | 70 ++----------------- .../sql/execution/GeneratedAggregate.scala | 41 +---------- .../spark/sql/execution/SparkStrategies.scala | 2 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 31 ++++++++ .../execution/HiveCompatibilitySuite.scala | 1 - ..._format-0-eff4ef3c207d14d5121368f294697964 | 0 ..._format-1-4a03c4328565c60ca99689239f07fb16 | 1 - 7 files changed, 37 insertions(+), 109 deletions(-) delete mode 100644 sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 delete mode 100644 sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 42343d4d8d79c..5d4b349b1597a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -404,7 +404,7 @@ case class Average(child: Expression) extends UnaryExpression with PartialAggreg // partialSum already increase the precision by 10 val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Sum(partialCount.toAttribute) + val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) SplitEvaluation( Cast(Divide(castedSum, castedCount), dataType), partialCount :: partialSum :: Nil) @@ -490,13 +490,13 @@ case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 case DecimalType.Fixed(_, _) => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), + Cast(Sum(partialSum.toAttribute), dataType), partialSum :: Nil) case _ => val partialSum = Alias(Sum(child), "PartialSum")() SplitEvaluation( - CombineSum(partialSum.toAttribute), + Sum(partialSum.toAttribute), partialSum :: Nil) } } @@ -522,8 +522,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg private val sum = 
MutableLiteral(null, calcType) - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) override def update(input: InternalRow): Unit = { sum.update(addFunction, input) @@ -538,67 +537,6 @@ case class SumFunction(expr: Expression, base: AggregateExpression1) extends Agg } } -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class CombineSumFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - - override def update(input: InternalRow): Unit = { - val result = expr.eval(input) - // partial sum result can be null only when no input rows present - if(result != null) { - sum.update(addFunction, input) - } - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { def this() = this(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 5ad4691a5ca07..1cd1420480f03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -108,7 +108,7 @@ case class GeneratedAggregate( Add( Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType) - ) :: currentSum :: zero :: Nil) + ) :: currentSum :: Nil) val result = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -118,45 +118,6 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case cs @ CombineSum(expr) => - val calcType = - expr.dataType match { - case DecimalType.Fixed(p, s) => - DecimalType.bounded(p + 10, s) - case _ => - expr.dataType - } - - val currentSum = AttributeReference("currentSum", calcType, nullable = true)() - val initialValue = Literal.create(null, calcType) - - // Coalesce avoids double 
calculation... - // but really, common sub expression elimination would be better.... - val zero = Cast(Literal(0), calcType) - // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its - // UnscaledValue will be null if and only if x is null; helps with Average on decimals - val actualExpr = expr match { - case UnscaledValue(e) => e - case _ => expr - } - // partial sum result can be null only when no input rows present - val updateFunction = If( - IsNotNull(actualExpr), - Coalesce( - Add( - Coalesce(currentSum :: zero :: Nil), - Cast(expr, calcType)) :: currentSum :: zero :: Nil), - currentSum) - - val result = - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(currentSum, cs.dataType) - case _ => currentSum - } - - AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case m @ Max(expr) => val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal.create(null, expr.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 306bbfec624c0..d88a02298c00d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -201,7 +201,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { - case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true + case _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true // The generated set implementation is pretty limited ATM. 
case CollectHashSet(exprs) if exprs.size == 1 && Seq(IntegerType, LongType).contains(exprs.head.dataType) => true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 358e319476e83..42724ed766af5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -227,6 +227,37 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Seq(Row("1"), Row("2"))) } + test("SPARK-8828 sum should return null if all input values are null") { + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) + } + } + } + test("aggregation with codegen") { val originalValue = sqlContext.conf.codegenEnabled sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b12b3838e615c..ec959cb2194b0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -822,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "udaf_covar_pop", "udaf_covar_samp", "udaf_histogram_numeric", - "udaf_number_format", "udf2", "udf5", "udf6", diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 b/sql/hive/src/test/resources/golden/udaf_number_format-0-eff4ef3c207d14d5121368f294697964 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 b/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 deleted file mode 100644 index c6f275a0db131..0000000000000 --- a/sql/hive/src/test/resources/golden/udaf_number_format-1-4a03c4328565c60ca99689239f07fb16 +++ /dev/null @@ -1 +0,0 @@ -0.0 NULL NULL NULL From 60f08c7c8775c0462b74bc65b41397be6eb24b6d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 22:51:15 -0700 Subject: [PATCH 104/219] [SPARK-9373][SQL] Support StructType in Tungsten projection This pull request updates GenerateUnsafeProjection to support StructType. If an input struct type is backed already by an UnsafeRow, GenerateUnsafeProjection copies the bytes directly into its buffer space without any conversion. However, if the input is not an UnsafeRow, GenerateUnsafeProjection runs the code generated recursively to convert the input into an UnsafeRow and then copies it into the buffer space. Also create a TungstenProject operator that projects data directly into UnsafeRow. 
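A minimal sketch of how the projection is meant to be used (assuming the generate()
entry point inherited from the CodeGenerator base class; "schema" and "internalRow"
are assumed inputs, the other names follow this patch):

    // Bind one expression per input column, then compile an UnsafeProjection.
    val exprs = schema.zipWithIndex.map { case (f, i) =>
      BoundReference(i, f.dataType, f.nullable)
    }
    val toUnsafe = GenerateUnsafeProjection.generate(exprs)
    val unsafeRow = toUnsafe(internalRow) // struct columns are copied or converted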
Note that I'm not sure if this is the way we want to structure Unsafe+codegen operators, but we can defer that decision to follow-up pull requests. Author: Reynold Xin Closes #7689 from rxin/tungsten-struct-type and squashes the following commits: 9162f42 [Reynold Xin] Support IntervalType in UnsafeRow's getter. be9f377 [Reynold Xin] Fixed tests. 10c4b7c [Reynold Xin] Format generated code. 77e8d0e [Reynold Xin] Fixed NondeterministicSuite. ac4951d [Reynold Xin] Yay. ac203bf [Reynold Xin] More comments. 9f36216 [Reynold Xin] Updated comment. 6b781fe [Reynold Xin] Reset the change in DataFrameSuite. 525b95b [Reynold Xin] Merged with master, more documentation & test cases. 321859a [Reynold Xin] [SPARK-9373][SQL] Support StructType in Tungsten projection [WIP] --- .../sql/catalyst/expressions/UnsafeRow.java | 2 + .../expressions/UnsafeRowWriters.java | 48 +++++- .../catalyst/expressions/BoundAttribute.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 162 ++++++++++++++++-- .../expressions/complexTypeCreator.scala | 94 ++++++++-- .../ArithmeticExpressionSuite.scala | 2 +- .../expressions/BitwiseFunctionsSuite.scala | 20 ++- .../expressions/ExpressionEvalHelper.scala | 26 ++- .../spark/sql/execution/SparkStrategies.scala | 9 +- .../spark/sql/execution/basicOperators.scala | 25 +++ .../spark/sql/DataFrameTungstenSuite.scala | 84 +++++++++ .../expression/NondeterministicSuite.scala | 2 +- 12 files changed, 430 insertions(+), 53 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index fb084dd13b620..955fb4226fc0e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -265,6 +265,8 @@ public Object get(int ordinal, DataType dataType) { return getBinary(ordinal); } else if (dataType instanceof StringType) { return getUTF8String(ordinal); + } else if (dataType instanceof IntervalType) { + return getInterval(ordinal); } else if (dataType instanceof StructType) { return getStruct(ordinal, ((StructType) dataType).size()); } else { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 0ba31d3b9b743..8fdd7399602d2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -81,6 +82,52 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) } } + /** + * Writer for struct type where the struct field is backed by an {@link UnsafeRow}. + * + * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}. + * Non-UnsafeRow struct fields are handled directly in + * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection} + * by generating the Java code needed to convert them into UnsafeRow. 
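+ * For example, a struct built by a regular (non-codegen) expression evaluator is a
+ * GenericMutableRow, which takes the recursive conversion path in the generated code
+ * rather than reaching this writer.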
+ */ + public static class StructWriter { + public static int getSize(InternalRow input) { + int numBytes = 0; + if (input instanceof UnsafeRow) { + numBytes = ((UnsafeRow) input).getSizeInBytes(); + } else { + // This is handled directly in GenerateUnsafeProjection. + throw new UnsupportedOperationException(); + } + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + + public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) { + int numBytes = 0; + final long offset = target.getBaseOffset() + cursor; + if (input instanceof UnsafeRow) { + final UnsafeRow row = (UnsafeRow) input; + numBytes = row.getSizeInBytes(); + + // zero-out the padding bytes + if ((numBytes & 0x07) > 0) { + PlatformDependent.UNSAFE.putLong( + target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); + } + + // Write the string to the variable length portion. + row.writeToMemory(target.getBaseObject(), offset); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + } else { + // This is handled directly in GenerateUnsafeProjection. + throw new UnsupportedOperationException(); + } + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } + /** Writer for interval type. */ public static class IntervalWriter { @@ -96,5 +143,4 @@ public static int write(UnsafeRow target, int ordinal, int cursor, Interval inpu return 16; } } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 41a877f214e55..8304d4ccd47f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -50,7 +50,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case BinaryType => input.getBinary(ordinal) case IntervalType => input.getInterval(ordinal) case t: StructType => input.getStruct(ordinal, t.size) - case dataType => input.get(ordinal, dataType) + case _ => input.get(ordinal, dataType) } } } @@ -64,10 +64,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def exprId: ExprId = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(dataType) + val value = ctx.getColumn("i", dataType, ordinal) s""" - boolean ${ev.isNull} = i.isNullAt($ordinal); - ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); + boolean ${ev.isNull} = i.isNullAt($ordinal); + $javaType ${ev.primitive} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($value); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 9d2161947b351..3e87f7285847c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -34,11 +34,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName + private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case t: AtomicType if !t.isInstanceOf[DecimalType] => true case _: IntervalType => true + case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true case _ => false } @@ -55,15 +57,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val ret = ev.primitive ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") - val bufferTerm = ctx.freshName("buffer") - ctx.addMutableState("byte[]", bufferTerm, s"$bufferTerm = new byte[64];") - val cursorTerm = ctx.freshName("cursor") - val numBytesTerm = ctx.freshName("numBytes") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + val numBytes = ctx.freshName("numBytes") - val exprs = expressions.map(_.gen(ctx)) + val exprs = expressions.zipWithIndex.map { case (e, i) => + e.dataType match { + case st: StructType => + createCodeForStruct(ctx, e.gen(ctx), st) + case _ => + e.gen(ctx) + } + } val allExprs = exprs.map(_.code).mkString("\n") - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) val additionalSize = expressions.zipWithIndex.map { case (e, i) => e.dataType match { case StringType => @@ -72,6 +81,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" case IntervalType => s" + (${exprs(i).isNull} ? 0 : 16)" + case _: StructType => + s" + (${exprs(i).isNull} ? 
0 : $StructWriter.getSize(${exprs(i).primitive}))" case _ => "" } }.mkString("") @@ -81,11 +92,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case dt if ctx.isPrimitiveType(dt) => s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" case StringType => - s"$cursorTerm += $StringWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" + s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" case BinaryType => - s"$cursorTerm += $BinaryWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" + s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" case IntervalType => - s"$cursorTerm += $IntervalWriter.write($ret, $i, $cursorTerm, ${exprs(i).primitive})" + s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") @@ -99,24 +112,139 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $allExprs - int $numBytesTerm = $fixedSize $additionalSize; - if ($numBytesTerm > $bufferTerm.length) { - $bufferTerm = new byte[$numBytesTerm]; + int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; } $ret.pointTo( - $bufferTerm, + $buffer, org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, ${expressions.size}, - $numBytesTerm); - int $cursorTerm = $fixedSize; - + $numBytes); + int $cursor = $fixedSize; $writers boolean ${ev.isNull} = false; """ } + /** + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. + * + * This function also handles nested structs by recursively generating the code to do conversion. + * + * @param ctx code generation context + * @param input the input struct, identified by a [[GeneratedExpressionCode]] + * @param schema schema of the struct field + */ + // TODO: refactor createCode and this function to reduce code duplication. + private def createCodeForStruct( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + schema: StructType): GeneratedExpressionCode = { + + val isNull = input.isNull + val primitive = ctx.freshName("structConvert") + ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val cursor = ctx.freshName("cursor") + + val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { + case (dt, i) => dt match { + case st: StructType => + val nestedStructEv = GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + ) + createCodeForStruct(ctx, nestedStructEv, st) + case _ => + GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + ) + } + } + val allExprs = exprs.map(_.code).mkString("\n") + + val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) + val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => + dt match { + case StringType => + s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 
0 : $BinaryWriter.getSize(${ev.primitive}))" + case IntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _ => "" + } + }.mkString("") + + val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => + val update = dt match { + case _ if ctx.isPrimitiveType(dt) => + s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}" + case StringType => + s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" + case IntervalType => + s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $dt") + } + s""" + if (${exprs(i).isNull}) { + $primitive.setNullAt($i); + } else { + $update; + } + """ + }.mkString("\n ") + + // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, + // just copy the bytes directly into our buffer space without running any conversion. + // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from + // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow. + val tmp = ctx.freshName("tmp") + val numBytes = ctx.freshName("numBytes") + val code = s""" + |${input.code} + |if (!${input.isNull}) { + | Object $tmp = (Object) ${input.primitive}; + | if ($tmp instanceof UnsafeRow) { + | $primitive = (UnsafeRow) $tmp; + | } else { + | $allExprs + | + | int $numBytes = $fixedSize $additionalSize; + | if ($numBytes > $buffer.length) { + | $buffer = new byte[$numBytes]; + | } + | + | $primitive.pointTo( + | $buffer, + | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + | ${exprs.size}, + | $numBytes); + | int $cursor = $fixedSize; + | + | $writers + | } + |} + """.stripMargin + + GeneratedExpressionCode(code, isNull, primitive) + } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -159,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } """ - logDebug(s"code for ${expressions.mkString(",")}:\n$code") + logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = compile(code) c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 119168fa59f15..d8c9087ff5380 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -104,18 +104,19 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ }.mkString("\n") } override def prettyName: 
String = "struct" } + /** * Creates a struct with the given field names and values * @@ -168,14 +169,83 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { valExprs.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" - if (${eval.isNull}) { - ${ev.primitive}.update($i, null); - } else { - ${ev.primitive}.update($i, ${eval.primitive}); - } - """ + if (${eval.isNull}) { + ${ev.primitive}.update($i, null); + } else { + ${ev.primitive}.update($i, ${eval.primitive}); + } + """ }.mkString("\n") } override def prettyName: String = "named_struct" } + +/** + * Returns a Row containing the evaluation of all children expressions. This is a variant that + * returns UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + */ +case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override lazy val resolved: Boolean = childrenResolved + + override lazy val dataType: StructType = { + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) + } + } + StructType(fields) + } + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, children) + } + + override def prettyName: String = "struct_unsafe" +} + + +/** + * Creates a struct with the given field names and values. This is a variant that returns + * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with + * this expression automatically at runtime. + * + * @param children Seq(name1, val1, name2, val2, ...) 
+ */ +case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression { + + private lazy val (nameExprs, valExprs) = + children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip + + private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + + override lazy val dataType: StructType = { + val fields = names.zip(valExprs).map { case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + } + StructType(fields) + } + + override def foldable: Boolean = valExprs.forall(_.foldable) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + } + + override def prettyName: String = "named_struct_unsafe" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e7e5231d32c9e..7773e098e0caa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -170,6 +170,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Pmod(-7, 3), 2) checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005) checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1)) - checkEvaluation(Pmod(2L, Long.MaxValue), 2) + checkEvaluation(Pmod(2L, Long.MaxValue), 2L) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala index 648fbf5a4c30b..fa30fbe528479 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseFunctionsSuite.scala @@ -30,8 +30,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, ~1.toByte) - check(1000.toShort, ~1000.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, (~1.toByte).toByte) + check(1000.toShort, (~1000.toShort).toShort) check(1000000, ~1000000) check(123456789123L, ~123456789123L) @@ -45,8 +46,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte & 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort & 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte & 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort & 2.toShort).toShort) check(1000000, 4, 1000000 & 4) check(123456789123L, 5L, 123456789123L & 5L) @@ -63,8 +65,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte | 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort | 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. 
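// (Reason for the cast, for readers of this patch: as in Java, Scala's bitwise
// operators promote Byte and Short operands to Int, so an expression such as
// `~1.toByte` has static type Int, while BitwiseNot on a ByteType column evaluates
// to a Byte. The explicit .toByte keeps the expected value's type aligned with the
// expression's result type.)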
+ check(1.toByte, 2.toByte, (1.toByte | 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort | 2.toShort).toShort) check(1000000, 4, 1000000 | 4) check(123456789123L, 5L, 123456789123L | 5L) @@ -81,8 +84,9 @@ class BitwiseFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, expected) } - check(1.toByte, 2.toByte, 1.toByte ^ 2.toByte) - check(1000.toShort, 2.toShort, 1000.toShort ^ 2.toShort) + // Need the extra toByte even though IntelliJ thought it's not needed. + check(1.toByte, 2.toByte, (1.toByte ^ 2.toByte).toByte) + check(1000.toShort, 2.toShort, (1000.toShort ^ 2.toShort).toShort) check(1000000, 4, 1000000 ^ 4) check(123456789123L, 5L, 123456789123L ^ 5L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index ab0cdc857c80e..136368bf5b368 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -114,7 +114,7 @@ trait ExpressionEvalHelper { val actual = plan(inputRow).get(0, expression.dataType) if (!checkResult(actual, expected)) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + fail(s"Incorrect evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -146,7 +146,8 @@ trait ExpressionEvalHelper { if (actual != expectedRow) { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + fail("Incorrect Evaluation in codegen mode: " + + s"$expression, actual: $actual, expected: $expectedRow$input") } if (actual.copy() != expectedRow) { fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") @@ -163,12 +164,21 @@ trait ExpressionEvalHelper { expression) val unsafeRow = plan(inputRow) - // UnsafeRow cannot be compared with GenericInternalRow directly - val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) - val expectedRow = InternalRow(expected) - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!unsafeRow.isNullAt(0)) { + val expectedRow = InternalRow(expected) + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } + } else { + val lit = InternalRow(expected) + val expectedRow = UnsafeProjection.create(Array(expression.dataType)).apply(lit) + if (unsafeRow != expectedRow) { + fail("Incorrect evaluation in unsafe mode: " + + s"$expression, actual: $unsafeRow, expected: $expectedRow$input") + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d88a02298c00d..314b85f126dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -363,7 +363,14 @@ private[sql] abstract class SparkStrategies extends 
QueryPlanner[SparkPlan] { case logical.Sort(sortExprs, global, child) => getSortOperator(sortExprs, global, planLater(child)):: Nil case logical.Project(projectList, child) => - execution.Project(projectList, planLater(child)) :: Nil + // If unsafe mode is enabled and we support these data types in Unsafe, use the + // Tungsten project. Otherwise, use the normal project. + if (sqlContext.conf.unsafeEnabled && + UnsafeProjection.canSupport(projectList) && UnsafeProjection.canSupport(child.schema)) { + execution.TungstenProject(projectList, planLater(child)) :: Nil + } else { + execution.Project(projectList, planLater(child)) :: Nil + } case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fe429d862a0a3..b02e60dc85cdd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -49,6 +49,31 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends override def outputOrdering: Seq[SortOrder] = child.outputOrdering } + +/** + * A variant of [[Project]] that returns [[UnsafeRow]]s. + */ +case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { + + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + override def output: Seq[Attribute] = projectList.map(_.toAttribute) + + protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => + this.transformAllExpressions { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + } + val project = UnsafeProjection.create(projectList, child.output) + iter.map(project) + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + + /** * :: DeveloperApi :: */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala new file mode 100644 index 0000000000000..bf8ef9a97bc60 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ + +/** + * An end-to-end test suite specifically for testing Tungsten (Unsafe/CodeGen) mode. + * + * This is here for now so I can make sure Tungsten project is tested without refactoring existing + * end-to-end test infra. In the long run this should just go away. + */ +class DataFrameTungstenSuite extends QueryTest with SQLTestUtils { + + override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ + + test("test simple types") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val df = sqlContext.sparkContext.parallelize(Seq((1, 2))).toDF("a", "b") + assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2)) + } + } + + test("test struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val struct = Row(1, 2L, 3.0F, 3.0) + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, struct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(struct)) + } + } + + test("test nested struct type") { + withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") { + val innerStruct = Row(1, "abcd") + val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg") + val data = sqlContext.sparkContext.parallelize(Seq(Row(1, outerStruct))) + + val schema = new StructType() + .add("a", IntegerType) + .add("b", + new StructType() + .add("b1", IntegerType) + .add("b2", LongType) + .add("b3", FloatType) + .add("b4", DoubleType) + .add("b5", new StructType() + .add("b5a", IntegerType) + .add("b5b", StringType)) + .add("b6", StringType)) + + val df = sqlContext.createDataFrame(data, schema) + assert(df.select("b").first() === Row(outerStruct)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala index 99e11fd64b2b9..1c5a2ed2c0a53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.execution.expressions.{SparkPartitionID, Monotonical class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { - checkEvaluation(MonotonicallyIncreasingID(), 0) + checkEvaluation(MonotonicallyIncreasingID(), 0L) } test("SparkPartitionID") { From 9c5612f4e197dec82a5eac9542896d6216a866b7 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 27 Jul 2015 23:02:23 -0700 Subject: [PATCH 105/219] [MINOR] [SQL] Support mutable expression unit test with codegen projection This actually contains 3 minor issues: 1) Enable the unit tests (codegen) for mutable expressions (FormatNumber, Regexp_Replace/Regexp_Extract) 2) Use the `PlatformDependent.copyMemory` instead of the `System.arraycopy` Author: Cheng Hao Closes #7566 from chenghao-intel/codegen_ut and squashes the following commits: 24f43ea [Cheng Hao] enable codegen for mutable expression & UTF8String performance --- .../expressions/stringOperations.scala | 1 - .../spark/sql/StringFunctionsSuite.scala | 34 
++++++++++++++----- .../apache/spark/unsafe/types/UTF8String.java | 32 ++++++++--------- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 38b0fb37dee3b..edfffbc01c7b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -777,7 +777,6 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType) override def dataType: DataType = IntegerType - protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 0f9c986f649a1..8e0ea76d15881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -57,19 +57,27 @@ class StringFunctionsSuite extends QueryTest { } test("string regex_replace / regex_extract") { - val df = Seq(("100-200", "")).toDF("a", "b") + val df = Seq( + ("100-200", "(\\d+)-(\\d+)", "300"), + ("100-200", "(\\d+)-(\\d+)", "400"), + ("100-200", "(\\d+)", "400")).toDF("a", "b", "c") checkAnswer( df.select( regexp_replace($"a", "(\\d+)", "num"), regexp_extract($"a", "(\\d+)-(\\d+)", 1)), - Row("num-num", "100")) - - checkAnswer( - df.selectExpr( - "regexp_replace(a, '(\\d+)', 'num')", - "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), - Row("num-num", "200")) + Row("num-num", "100") :: Row("num-num", "100") :: Row("num-num", "100") :: Nil) + + // for testing the mutable state of the expression in code gen. + // This is a hacky way to enable the codegen: though the codegen is enabled by default, + // it will still use the interpretProjection if the projection is applied directly + // on a LocalRelation, hence we add a filter operator. + // See the optimizer rule `ConvertToLocalRelation` + checkAnswer( + df.filter("isnotnull(a)").selectExpr( + "regexp_replace(a, b, c)", + "regexp_extract(a, b, 1)"), + Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil) } test("string ascii function") { @@ -290,5 +298,15 @@ class StringFunctionsSuite extends QueryTest { df.selectExpr("format_number(e, g)"), // decimal type of the 2nd argument is unacceptable Row("5.0000")) } + + // for testing the mutable state of the expression in code gen. + // This is a hacky way to enable the codegen: though the codegen is enabled by default, + // it will still use the interpretProjection if the projection is applied directly + // on a LocalRelation, hence we add a filter operator. 
+ // See the optimizer rule `ConvertToLocalRelation` + val df2 = Seq((5L, 4), (4L, 3), (3L, 2)).toDF("a", "b") + checkAnswer( + df2.filter("b>0").selectExpr("format_number(a, b)"), + Row("5.0000") :: Row("4.000") :: Row("3.00") :: Nil) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 85381cf0ef425..3e1cc67dbf337 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -300,13 +300,13 @@ public UTF8String trimRight() { } public UTF8String reverse() { - byte[] bytes = getBytes(); - byte[] result = new byte[bytes.length]; + byte[] result = new byte[this.numBytes]; int i = 0; // position in byte while (i < numBytes) { int len = numBytesForFirstByte(getByte(i)); - System.arraycopy(bytes, i, result, result.length - i - len, len); + copyMemory(this.base, this.offset + i, result, + BYTE_ARRAY_OFFSET + result.length - i - len, len); i += len; } @@ -316,11 +316,11 @@ public UTF8String reverse() { public UTF8String repeat(int times) { if (times <=0) { - return fromBytes(new byte[0]); + return EMPTY_UTF8; } byte[] newBytes = new byte[numBytes * times]; - System.arraycopy(getBytes(), 0, newBytes, 0, numBytes); + copyMemory(this.base, this.offset, newBytes, BYTE_ARRAY_OFFSET, numBytes); int copied = 1; while (copied < times) { @@ -385,16 +385,15 @@ public UTF8String rpad(int len, UTF8String pad) { UTF8String remain = pad.substring(0, spaces - padChars * count); byte[] data = new byte[this.numBytes + pad.numBytes * count + remain.numBytes]; - System.arraycopy(getBytes(), 0, data, 0, this.numBytes); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET, this.numBytes); int offset = this.numBytes; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); return UTF8String.fromBytes(data); } @@ -421,15 +420,14 @@ public UTF8String lpad(int len, UTF8String pad) { int offset = 0; int idx = 0; - byte[] padBytes = pad.getBytes(); while (idx < count) { - System.arraycopy(padBytes, 0, data, offset, pad.numBytes); + copyMemory(pad.base, pad.offset, data, BYTE_ARRAY_OFFSET + offset, pad.numBytes); ++idx; offset += pad.numBytes; } - System.arraycopy(remain.getBytes(), 0, data, offset, remain.numBytes); + copyMemory(remain.base, remain.offset, data, BYTE_ARRAY_OFFSET + offset, remain.numBytes); offset += remain.numBytes; - System.arraycopy(getBytes(), 0, data, offset, numBytes()); + copyMemory(this.base, this.offset, data, BYTE_ARRAY_OFFSET + offset, numBytes()); return UTF8String.fromBytes(data); } @@ -454,9 +452,9 @@ public static UTF8String concat(UTF8String... inputs) { int offset = 0; for (int i = 0; i < inputs.length; i++) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, - result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, + result, BYTE_ARRAY_OFFSET + offset, len); offset += len; } @@ -494,7 +492,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... 
inputs) { for (int i = 0, j = 0; i < inputs.length; i++) { if (inputs[i] != null) { int len = inputs[i].numBytes; - PlatformDependent.copyMemory( + copyMemory( inputs[i].base, inputs[i].offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, len); @@ -503,7 +501,7 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... inputs) { j++; // Add separator if this is not the last input. if (j < numInputs) { - PlatformDependent.copyMemory( + copyMemory( separator.base, separator.offset, result, PlatformDependent.BYTE_ARRAY_OFFSET + offset, separator.numBytes); From d93ab93d673c5007a1edb90a424b451c91c8a285 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 27 Jul 2015 23:34:29 -0700 Subject: [PATCH 106/219] [SPARK-9335] [STREAMING] [TESTS] Make sure the test stream is deleted in KinesisBackedBlockRDDSuite KinesisBackedBlockRDDSuite should make sure to delete the stream. Author: zsxwing Closes #7663 from zsxwing/fix-SPARK-9335 and squashes the following commits: f0e9154 [zsxwing] Revert "[HOTFIX] - Disable Kinesis tests due to rate limits" 71a4552 [zsxwing] Make sure the test stream is deleted --- .../streaming/kinesis/KinesisBackedBlockRDDSuite.scala | 7 +++++-- .../spark/streaming/kinesis/KinesisStreamSuite.scala | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index b2e2a4246dbd5..e81fb11e5959f 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.streaming.kinesis -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} -import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException} class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { @@ -65,6 +65,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll } override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.deleteStream() + } if (sc != null) { sc.stop() } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 4992b041765e9..f9c952b9468bb 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -59,7 +59,7 @@ class KinesisStreamSuite extends KinesisFunSuite } } - ignore("KinesisUtils API") { + test("KinesisUtils API") { ssc = new StreamingContext(sc, Seconds(1)) // Tests the API, does not actually test data receiving val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream", @@ -83,7 +83,7 @@ class KinesisStreamSuite extends KinesisFunSuite * you must have AWS credentials available through the default AWS provider chain, * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . 
*/ - ignore("basic operation") { + testIfEnabled("basic operation") { val kinesisTestUtils = new KinesisTestUtils() try { kinesisTestUtils.createStream() From fc3bd96bc3e4a1a2a1eb9b982b3468abd137e395 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 27 Jul 2015 23:56:16 -0700 Subject: [PATCH 107/219] Closes #6836 since Round has already been implemented. From 15724fac569258d2a149507d8c767d0de0ae8306 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Jul 2015 00:52:26 -0700 Subject: [PATCH 108/219] [SPARK-9394][SQL] Handle parentheses in CodeFormatter. Our CodeFormatter currently does not handle parentheses, and as a result in code dump, we see code formatted this way: ``` foo( a, b, c) ``` With this patch, it is formatted this way: ``` foo( a, b, c) ``` Author: Reynold Xin Closes #7712 from rxin/codeformat-parentheses and squashes the following commits: c2b1c5f [Reynold Xin] Took square bracket out 3cfb174 [Reynold Xin] Code review feedback. 91f5bb1 [Reynold Xin] [SPARK-9394][SQL] Handle parentheses in CodeFormatter. --- .../expressions/codegen/CodeFormatter.scala | 8 ++--- .../codegen/CodeFormatterSuite.scala | 30 +++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 2087cc7f109bc..c98182c96b165 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen /** - * An utility class that indents a block of code based on the curly braces. - * + * An utility class that indents a block of code based on the curly braces and parentheses. * This is used to prettify generated code when in debug mode (or exceptions). * * Written by Matei Zaharia. 
@@ -35,11 +34,12 @@ private class CodeFormatter { private var indentString = "" private def addLine(line: String): Unit = { - val indentChange = line.count(_ == '{') - line.count(_ == '}') + val indentChange = + line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) val newIndentLevel = math.max(0, indentLevel + indentChange) // Lines starting with '}' should be de-indented even if they contain '{' after; // in addition, lines ending with ':' are typically labels - val thisLineIndent = if (line.startsWith("}") || line.endsWith(":")) { + val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) { " " * (indentSize * (indentLevel - 1)) } else { indentString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 478702fea6146..46daa3eb8bf80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -73,4 +73,34 @@ class CodeFormatterSuite extends SparkFunSuite { |} """.stripMargin } + + testCase("if else on the same line") { + """ + |class A { + | if (c) {duh;} else {boo;} + |} + """.stripMargin + }{ + """ + |class A { + | if (c) {duh;} else {boo;} + |} + """.stripMargin + } + + testCase("function calls") { + """ + |foo( + |a, + |b, + |c) + """.stripMargin + }{ + """ + |foo( + | a, + | b, + | c) + """.stripMargin + } } From ac8c549e2fa9ff3451deb4c3e49d151eeac18acc Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Tue, 28 Jul 2015 15:57:21 +0100 Subject: [PATCH 109/219] [EC2] Cosmetic fix for usage of spark-ec2 --ebs-vol-num option The last line of the usage seems ugly. ``` $ spark-ec2 --help --ebs-vol-num=EBS_VOL_NUM Number of EBS volumes to attach to each node as /vol[x]. The volumes will be deleted when the instances terminate. Only possible on EBS-backed AMIs. EBS volumes are only attached if --ebs-vol-size > 0.Only support up to 8 EBS volumes. ``` After applying this patch: ``` $ spark-ec2 --help --ebs-vol-num=EBS_VOL_NUM Number of EBS volumes to attach to each node as /vol[x]. The volumes will be deleted when the instances terminate. Only possible on EBS-backed AMIs. EBS volumes are only attached if --ebs-vol-size > 0. Only support up to 8 EBS volumes. ``` As this is a trivial thing I didn't create JIRA for this. Author: Kenichi Maehashi Closes #7632 from kmaehashi/spark-ec2-cosmetic-fix and squashes the following commits: 526c118 [Kenichi Maehashi] cosmetic fix for spark-ec2 --ebs-vol-num option usage --- ec2/spark_ec2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7c83d68e7993e..ccf922d9371fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -242,7 +242,7 @@ def parse_args(): help="Number of EBS volumes to attach to each node as /vol[x]. " + "The volumes will be deleted when the instances terminate. " + "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0." + + "EBS volumes are only attached if --ebs-vol-size > 0. 
" + "Only support up to 8 EBS volumes.") parser.add_option( "--placement-group", type="string", default=None, From 4af622c855a32b1846242a6dd38b252ca30c8b82 Mon Sep 17 00:00:00 2001 From: vinodkc Date: Tue, 28 Jul 2015 08:48:57 -0700 Subject: [PATCH 110/219] [SPARK-8919] [DOCUMENTATION, MLLIB] Added @since tags to mllib.recommendation Author: vinodkc Closes #7325 from vinodkc/add_since_mllib.recommendation and squashes the following commits: 93156f2 [vinodkc] Changed 0.8.0 to 0.9.1 c413350 [vinodkc] Added @since --- .../spark/mllib/recommendation/ALS.scala | 10 +++++ .../MatrixFactorizationModel.scala | 38 ++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 93290e6508529..56c549ef99cb7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -26,6 +26,7 @@ import org.apache.spark.storage.StorageLevel /** * A more compact class to represent a rating than Tuple3[Int, Int, Double]. + * @since 0.8.0 */ case class Rating(user: Int, product: Int, rating: Double) @@ -254,6 +255,7 @@ class ALS private ( /** * Top-level methods for calling Alternating Least Squares (ALS) matrix factorization. + * @since 0.8.0 */ object ALS { /** @@ -269,6 +271,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param seed random seed + * @since 0.9.1 */ def train( ratings: RDD[Rating], @@ -293,6 +296,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into + * @since 0.8.0 */ def train( ratings: RDD[Rating], @@ -315,6 +319,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) + * @since 0.8.0 */ def train(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double) : MatrixFactorizationModel = { @@ -331,6 +336,7 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) + * @since 0.8.0 */ def train(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { @@ -351,6 +357,7 @@ object ALS { * @param blocks level of parallelism to split computation into * @param alpha confidence parameter * @param seed random seed + * @since 0.8.1 */ def trainImplicit( ratings: RDD[Rating], @@ -377,6 +384,7 @@ object ALS { * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into * @param alpha confidence parameter + * @since 0.8.1 */ def trainImplicit( ratings: RDD[Rating], @@ -401,6 +409,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param alpha confidence parameter + * @since 0.8.1 */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { @@ -418,6 +427,7 @@ object ALS { * @param ratings RDD of (userID, productID, rating) pairs * @param rank number of features to use * @param iterations number of iterations of 
ALS (recommended: 10-20) + * @since 0.8.1 */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) : MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 43d219a49cf4e..261ca9cef0c5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -49,6 +49,7 @@ import org.apache.spark.storage.StorageLevel * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * @since 0.8.0 */ class MatrixFactorizationModel( val rank: Int, @@ -73,7 +74,9 @@ class MatrixFactorizationModel( } } - /** Predict the rating of one user for one product. */ + /** Predict the rating of one user for one product. + * @since 0.8.0 + */ def predict(user: Int, product: Int): Double = { val userVector = userFeatures.lookup(user).head val productVector = productFeatures.lookup(product).head @@ -111,6 +114,7 @@ class MatrixFactorizationModel( * * @param usersProducts RDD of (user, product) pairs. * @return RDD of Ratings. + * @since 0.9.0 */ def predict(usersProducts: RDD[(Int, Int)]): RDD[Rating] = { // Previously the partitions of ratings are only based on the given products. @@ -142,6 +146,7 @@ class MatrixFactorizationModel( /** * Java-friendly version of [[MatrixFactorizationModel.predict]]. + * @since 1.2.0 */ def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() @@ -157,6 +162,7 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the user. The score is an opaque value that indicates how strongly * recommended the product is. + * @since 1.1.0 */ def recommendProducts(user: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(userFeatures.lookup(user).head, productFeatures, num) @@ -173,6 +179,7 @@ class MatrixFactorizationModel( * by score, decreasing. The first returned is the one predicted to be most strongly * recommended to the product. The score is an opaque value that indicates how strongly * recommended the user is. + * @since 1.1.0 */ def recommendUsers(product: Int, num: Int): Array[Rating] = MatrixFactorizationModel.recommend(productFeatures.lookup(product).head, userFeatures, num) @@ -180,6 +187,20 @@ class MatrixFactorizationModel( protected override val formatVersion: String = "1.0" + /** + * Save this model to the given path. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[Loader.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. 
+ * @since 1.3.0 + */ override def save(sc: SparkContext, path: String): Unit = { MatrixFactorizationModel.SaveLoadV1_0.save(this, path) } @@ -191,6 +212,7 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a userID and an array of * rating objects which contains the same userId, recommended productID and a "score" in the * rating field. Semantics of score is same as recommendProducts API + * @since 1.4.0 */ def recommendProductsForUsers(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, userFeatures, productFeatures, num).map { @@ -208,6 +230,7 @@ class MatrixFactorizationModel( * @return [(Int, Array[Rating])] objects, where every tuple contains a productID and an array * of rating objects which contains the recommended userId, same productID and a "score" in the * rating field. Semantics of score is same as recommendUsers API + * @since 1.4.0 */ def recommendUsersForProducts(num: Int): RDD[(Int, Array[Rating])] = { MatrixFactorizationModel.recommendForAll(rank, productFeatures, userFeatures, num).map { @@ -218,6 +241,9 @@ class MatrixFactorizationModel( } } +/** + * @since 1.3.0 + */ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { import org.apache.spark.mllib.util.Loader._ @@ -292,6 +318,16 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] { } } + /** + * Load a model from the given path. + * + * The model should have been saved by [[Saveable.save]]. + * + * @param sc Spark context used for loading model files. + * @param path Path specifying the directory to which the model was saved. + * @return Model instance + * @since 1.3.0 + */ override def load(sc: SparkContext, path: String): MatrixFactorizationModel = { val (loadedClassName, formatVersion, _) = loadMetadata(sc, path) val classNameV1_0 = SaveLoadV1_0.thisClassName From 5a2330e546074013ef706ac09028626912ec5475 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Jul 2015 09:42:35 -0700 Subject: [PATCH 111/219] [SPARK-9402][SQL] Remove CodegenFallback from Abs / FormatNumber. Both expressions already implement code generation. Author: Reynold Xin Closes #7723 from rxin/abs-formatnum and squashes the following commits: 31ed765 [Reynold Xin] [SPARK-9402][SQL] Remove CodegenFallback from Abs / FormatNumber. 
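For context, a rough sketch (not the Spark source) of what the `CodegenFallback` mixin provides: a `genCode` that bails out to the interpreted `eval`, so leaving it on an expression that already generates Java code only hides the codegen path from tests.

```scala
// Simplified illustration only; names mirror the surrounding patch, and the
// shape of the generated Java is an assumption, not the actual implementation.
trait CodegenFallbackSketch { self: Expression =>
  def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Remember the expression object and have the generated code call back into
    // its interpreted eval() instead of emitting specialized Java for it.
    val idx = ctx.references.length
    ctx.references += this
    s"""
      Object ${ev.primitive} = expressions[$idx].eval(i);
      boolean ${ev.isNull} = ${ev.primitive} == null;
    """
  }
}
```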
--- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 3 +-- .../spark/sql/catalyst/expressions/stringOperations.scala | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b37f530ec6814..4ec866475f8b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -68,8 +68,7 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects @ExpressionDescription( usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) - extends UnaryExpression with ExpectsInputTypes with CodegenFallback { +case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index edfffbc01c7b0..6db4e19c24ed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -1139,7 +1139,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio * fractional part. */ case class FormatNumber(x: Expression, d: Expression) - extends BinaryExpression with ExpectsInputTypes with CodegenFallback { + extends BinaryExpression with ExpectsInputTypes { override def left: Expression = x override def right: Expression = d From c740bed17215a9608c9eb9d80ffdf0fcf72c3911 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Jul 2015 09:43:12 -0700 Subject: [PATCH 112/219] [SPARK-9373][SQL] follow up for StructType support in Tungsten projection. Author: Reynold Xin Closes #7720 from rxin/struct-followup and squashes the following commits: d9757f5 [Reynold Xin] [SPARK-9373][SQL] follow up for StructType support in Tungsten projection. --- .../expressions/UnsafeRowWriters.java | 6 +-- .../codegen/GenerateUnsafeProjection.scala | 40 +++++++++---------- .../spark/sql/execution/SparkStrategies.scala | 3 +- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 8fdd7399602d2..32faad374015c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -47,7 +47,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String in target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } - // Write the string to the variable length portion. + // Write the bytes to the variable length portion. input.writeToMemory(target.getBaseObject(), offset); // Set the fixed length portion. 
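// Note (editorial aside, not part of the patch): in the hunk above,
// ((numBytes >> 3) << 3) is numBytes rounded down to a multiple of 8, i.e. the
// offset of the last, partially filled 8-byte word of the variable-length region.
// Zeroing that word before the copy makes the trailing pad bytes deterministic,
// so rows holding equal values also compare equal byte-for-byte.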
@@ -73,7 +73,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } - // Write the string to the variable length portion. + // Write the bytes to the variable length portion. ByteArray.writeToMemory(input, target.getBaseObject(), offset); // Set the fixed length portion. @@ -115,7 +115,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); } - // Write the string to the variable length portion. + // Write the bytes to the variable length portion. row.writeToMemory(target.getBaseObject(), offset); // Set the fixed length portion. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 3e87f7285847c..9a4c00e86a3ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -62,14 +62,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val cursor = ctx.freshName("cursor") val numBytes = ctx.freshName("numBytes") - val exprs = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case st: StructType => - createCodeForStruct(ctx, e.gen(ctx), st) - case _ => - e.gen(ctx) - } - } + val exprs = expressions.map { e => e.dataType match { + case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st) + case _ => e.gen(ctx) + }} val allExprs = exprs.map(_.code).mkString("\n") val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) @@ -153,20 +149,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => dt match { - case st: StructType => - val nestedStructEv = GeneratedExpressionCode( - code = "", - isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getColumn(input.primitive, dt, i)}" - ) - createCodeForStruct(ctx, nestedStructEv, st) - case _ => - GeneratedExpressionCode( - code = "", - isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getColumn(input.primitive, dt, i)}" - ) - } + case st: StructType => + val nestedStructEv = GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + ) + createCodeForStruct(ctx, nestedStructEv, st) + case _ => + GeneratedExpressionCode( + code = "", + isNull = s"${input.primitive}.isNullAt($i)", + primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + ) + } } val allExprs = exprs.map(_.code).mkString("\n") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 314b85f126dd2..f3ef066528ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -339,7 +339,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. 
*/ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) { + if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && + UnsafeExternalSort.supportsSchema(child.schema)) { execution.UnsafeExternalSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) From 9bbe0171cb434edb160fad30ea2d4221f525c919 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Jul 2015 09:43:39 -0700 Subject: [PATCH 113/219] [SPARK-8196][SQL] Fix null handling & documentation for next_day. The original patch didn't handle nulls correctly for next_day. Author: Reynold Xin Closes #7718 from rxin/next_day and squashes the following commits: 616a425 [Reynold Xin] Merged DatetimeExpressionsSuite into DateFunctionsSuite. faa78cf [Reynold Xin] Merged DatetimeFunctionsSuite into DateExpressionsSuite. 6c4fb6a [Reynold Xin] [SPARK-8196][SQL] Fix null handling & documentation for next_day. --- .../sql/catalyst/expressions/Expression.scala | 12 ++--- .../expressions/datetimeFunctions.scala | 46 ++++++++++------- .../sql/catalyst/expressions/literals.scala | 2 +- .../sql/catalyst/util/DateTimeUtils.scala | 2 +- .../expressions/DateExpressionsSuite.scala | 43 +++++++++++++--- .../expressions/DatetimeFunctionsSuite.scala | 37 -------------- .../expressions/ExpressionEvalHelper.scala | 1 + .../expressions/NonFoldableLiteral.scala | 50 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 20 +++++--- .../apache/spark/sql/DateFunctionsSuite.scala | 21 ++++++++ .../spark/sql/DatetimeExpressionsSuite.scala | 48 ------------------ 11 files changed, 158 insertions(+), 124 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index cb4c3f24b2721..03e36c7871bcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -355,9 +355,9 @@ abstract class BinaryExpression extends Expression { * @param f accepts two variable names and returns Java code to compute the output. */ protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.primitive} = ${f(eval1, eval2)};" }) @@ -372,9 +372,9 @@ abstract class BinaryExpression extends Expression { * and returns Java code to compute the output. 
*/ protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.primitive, eval2.primitive) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index b00a1b26fa285..c37afc13f2d17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -276,8 +276,6 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC override def dataType: DataType = DateType - override def prettyName: String = "last_day" - override def nullSafeEval(date: Any): Any = { val days = date.asInstanceOf[Int] DateTimeUtils.getLastDayOfMonth(days) @@ -289,12 +287,16 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC s"$dtu.getLastDayOfMonth($sd)" }) } + + override def prettyName: String = "last_day" } /** * Returns the first date which is later than startDate and named as dayOfWeek. * For example, NextDay(2015-07-27, Sunday) would return 2015-08-02, which is the first - * sunday later than 2015-07-27. + * Sunday later than 2015-07-27. + * + * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. */ case class NextDay(startDate: Expression, dayOfWeek: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -318,22 +320,32 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { - val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val dow = ctx.freshName("dow") - val genDow = if (right.foldable) { - val dowVal = DateTimeUtils.getDayOfWeekFromString( - dayOfWeek.eval(InternalRow.empty).asInstanceOf[UTF8String]) - s"int $dow = $dowVal;" - } else { - s"int $dow = $dtu.getDayOfWeekFromString($dowS);" - } - genDow + s""" - if ($dow == -1) { - ${ev.isNull} = true; + val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") + val dayOfWeekTerm = ctx.freshName("dayOfWeek") + if (dayOfWeek.foldable) { + val input = dayOfWeek.eval().asInstanceOf[UTF8String] + if ((input eq null) || DateTimeUtils.getDayOfWeekFromString(input) == -1) { + s""" + |${ev.isNull} = true; + """.stripMargin } else { - ${ev.primitive} = $dtu.getNextDateForDayOfWeek($sd, $dow); + val dayOfWeekValue = DateTimeUtils.getDayOfWeekFromString(input) + s""" + |${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekValue); + """.stripMargin } - """ + } else { + s""" + |int $dayOfWeekTerm = $dateTimeUtilClass.getDayOfWeekFromString($dowS); + |if ($dayOfWeekTerm == -1) { + | ${ev.isNull} = true; + |} else { + | ${ev.primitive} = $dateTimeUtilClass.getNextDateForDayOfWeek($sd, $dayOfWeekTerm); + |} + """.stripMargin + } }) } + + override def prettyName: String = "next_day" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 85060b7893556..064a1720c36e8 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -118,7 +118,7 @@ case class Literal protected (value: Any, dataType: DataType) super.genCode(ctx, ev) } else { ev.isNull = "false" - ev.primitive = s"${value}" + ev.primitive = s"${value}D" "" } case ByteType | ShortType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 2e28fb9af9b65..8b0b80c26db17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -575,7 +575,7 @@ object DateTimeUtils { } /** - * Returns Day of week from String. Starting from Thursday, marked as 0. + * Returns day of week from String. Starting from Thursday, marked as 0. * (Because 1970-01-01 is Thursday). */ def getDayOfWeekFromString(string: UTF8String): Int = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 4d2d33765a269..30c5769424bd7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -32,6 +32,19 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) val ts = new Timestamp(sdf.parse("2013-11-08 13:10:15").getTime) + test("datetime function current_date") { + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) + } + + test("datetime function current_timestamp") { + val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) + val t1 = System.currentTimeMillis() + assert(math.abs(t1 - ct.getTime) < 5000) + } + test("DayOfYear") { val sdfDay = new SimpleDateFormat("D") (2002 to 2004).foreach { y => @@ -264,14 +277,28 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("next_day") { + def testNextDay(input: String, dayOfWeek: String, output: String): Unit = { + checkEvaluation( + NextDay(Literal(Date.valueOf(input)), NonFoldableLiteral(dayOfWeek)), + DateTimeUtils.fromJavaDate(Date.valueOf(output))) + checkEvaluation( + NextDay(Literal(Date.valueOf(input)), Literal(dayOfWeek)), + DateTimeUtils.fromJavaDate(Date.valueOf(output))) + } + testNextDay("2015-07-23", "Mon", "2015-07-27") + testNextDay("2015-07-23", "mo", "2015-07-27") + testNextDay("2015-07-23", "Tue", "2015-07-28") + testNextDay("2015-07-23", "tu", "2015-07-28") + testNextDay("2015-07-23", "we", "2015-07-29") + testNextDay("2015-07-23", "wed", "2015-07-29") + testNextDay("2015-07-23", "Thu", "2015-07-30") + testNextDay("2015-07-23", "TH", "2015-07-30") + testNextDay("2015-07-23", "Fri", "2015-07-24") + testNextDay("2015-07-23", "fr", "2015-07-24") + + checkEvaluation(NextDay(Literal(Date.valueOf("2015-07-23")), Literal("xx")), null) + checkEvaluation(NextDay(Literal.create(null, DateType), Literal("xx")), null) checkEvaluation( - 
NextDay(Literal(Date.valueOf("2015-07-23")), Literal("Thu")), - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-30"))) - checkEvaluation( - NextDay(Literal(Date.valueOf("2015-07-23")), Literal("THURSDAY")), - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-30"))) - checkEvaluation( - NextDay(Literal(Date.valueOf("2015-07-23")), Literal("th")), - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-30"))) + NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala deleted file mode 100644 index 1618c24871c60..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DatetimeFunctionsSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.DateTimeUtils - -class DatetimeFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { - test("datetime function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val cd = CurrentDate().eval(EmptyRow).asInstanceOf[Int] - val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) - } - - test("datetime function current_timestamp") { - val ct = DateTimeUtils.toJavaTimestamp(CurrentTimestamp().eval(EmptyRow).asInstanceOf[Long]) - val t1 = System.currentTimeMillis() - assert(math.abs(t1 - ct.getTime) < 5000) - } - -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 136368bf5b368..0c8611d5ddefa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -82,6 +82,7 @@ trait ExpressionEvalHelper { s""" |Code generation of $expression failed: |$e + |${e.getStackTraceString} """.stripMargin) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala new file mode 100644 index 0000000000000..0559fb80e7fce --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types._ + + +/** + * A literal value that is not foldable. Used in expression codegen testing to test code path + * that behave differently based on foldable values. + */ +case class NonFoldableLiteral(value: Any, dataType: DataType) + extends LeafExpression with CodegenFallback { + + override def foldable: Boolean = false + override def nullable: Boolean = true + + override def toString: String = if (value != null) value.toString else "null" + + override def eval(input: InternalRow): Any = value + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + Literal.create(value, dataType).genCode(ctx, ev) + } +} + + +object NonFoldableLiteral { + def apply(value: Any): NonFoldableLiteral = { + val lit = Literal(value) + NonFoldableLiteral(lit.value, lit.dataType) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d18558b510f0b..cec61b66b157c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2033,7 +2033,10 @@ object functions { def hour(columnName: String): Column = hour(Column(columnName)) /** - * Returns the last day of the month which the given date belongs to. + * Given a date column, returns the last day of the month which the given date belongs to. + * For example, input "2015-07-27" returns "2015-07-31" since July 31 is the last day of the + * month in July 2015. + * * @group datetime_funcs * @since 1.5.0 */ @@ -2054,14 +2057,19 @@ object functions { def minute(columnName: String): Column = minute(Column(columnName)) /** - * Returns the first date which is later than given date sd and named as dow. - * For example, `next_day('2015-07-27', "Sunday")` would return 2015-08-02, which is the - * first Sunday later than 2015-07-27. The parameter dayOfWeek could be 2-letter, 3-letter, - * or full name of the day of the week (e.g. Mo, tue, FRIDAY). + * Given a date column, returns the first date which is later than the value of the date column + * that is on the specified day of the week. + * + * For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first + * Sunday after 2015-07-27. + * + * Day of the week parameter is case insensitive, and accepts: + * "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". 
+ * * @group datetime_funcs * @since 1.5.0 */ - def next_day(sd: Column, dayOfWeek: String): Column = NextDay(sd.expr, lit(dayOfWeek).expr) + def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) /** * Extracts the seconds as an integer from a given date/timestamp/string. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 001fcd035c82a..36820cbbc7e5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.{Timestamp, Date} import java.text.SimpleDateFormat +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ class DateFunctionsSuite extends QueryTest { @@ -27,6 +28,26 @@ class DateFunctionsSuite extends QueryTest { import ctx.implicits._ + test("function current_date") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) + val d2 = DateTimeUtils.fromJavaDate( + ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) + val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) + assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) + } + + test("function current_timestamp") { + val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") + checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) + // Execution in one query should return the same value + checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), + Row(true)) + assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( + 0).getTime - System.currentTimeMillis()) < 5000) + } + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val sdfDate = new SimpleDateFormat("yyyy-MM-dd") val d = new Date(sdf.parse("2015-04-08 13:10:15").getTime) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala deleted file mode 100644 index 44b915304533c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatetimeExpressionsSuite.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.functions._ - -class DatetimeExpressionsSuite extends QueryTest { - private lazy val ctx = org.apache.spark.sql.test.TestSQLContext - - import ctx.implicits._ - - lazy val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") - - test("function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0)) - val d2 = DateTimeUtils.fromJavaDate( - ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0)) - val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis()) - assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) - } - - test("function current_timestamp") { - checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) - // Execution in one query should return the same value - checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""), - Row(true)) - assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp( - 0).getTime - System.currentTimeMillis()) < 5000) - } - -} From 35ef853b3f9d955949c464e4a0d445147e0e9a07 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Tue, 28 Jul 2015 10:12:09 -0700 Subject: [PATCH 114/219] [SPARK-9397] DataFrame should provide an API to find source data files if applicable Certain applications would benefit from being able to inspect DataFrames that are produced directly from file-based data sources, and to find out which files back them. For example, one might want to display to a user the size of the data underlying a table, or to copy or mutate it. This PR exposes an `inputFiles` method on DataFrame that attempts to discover the source data in a best-effort manner by inspecting HadoopFsRelations and JSONRelations.
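To make the intended call pattern concrete, here is a minimal usage sketch (not part of the patch; the context variable and path are made up, and a file-based source such as Parquet is assumed):

```scala
// Hypothetical usage of the new API; `sqlContext` and the path are placeholders.
val df = sqlContext.read.parquet("/data/events")
// Best-effort: unions the paths reported by each file-based relation in the
// plan; relations that are not file-based contribute nothing.
df.inputFiles.foreach(println)
```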
Author: Aaron Davidson Closes #7717 from aarondav/paths and squashes the following commits: ff67430 [Aaron Davidson] inputFiles 0acd3ad [Aaron Davidson] [SPARK-9397] DataFrame should provide an API to find source data files if applicable --- .../org/apache/spark/sql/DataFrame.scala | 20 +++++++++++++++++-- .../org/apache/spark/sql/DataFrameSuite.scala | 20 +++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 6 +++--- 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 114ab91d10aa0..3ea0f9ed3bddd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -40,8 +40,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} -import org.apache.spark.sql.execution.datasources.CreateTableUsingAsSelect -import org.apache.spark.sql.json.JacksonGenerator +import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} +import org.apache.spark.sql.json.{JacksonGenerator, JSONRelation} +import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -1546,6 +1547,21 @@ class DataFrame private[sql]( } } + /** + * Returns a best-effort snapshot of the files that compose this DataFrame. This method simply + * asks each constituent BaseRelation for its respective files and takes the union of all results. + * Depending on the source relations, this may not find all input files. Duplicates are removed. 
+ */ + def inputFiles: Array[String] = { + val files: Seq[String] = logicalPlan.collect { + case LogicalRelation(fsBasedRelation: HadoopFsRelation) => + fsBasedRelation.paths.toSeq + case LogicalRelation(jsonRelation: JSONRelation) => + jsonRelation.path.toSeq + }.flatten + files.toSet.toArray + } + //////////////////////////////////////////////////////////////////////////// // for Python API //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f67f2c60c0e16..3151e071b19ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -23,7 +23,10 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ +import org.apache.spark.sql.json.JSONRelation +import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} @@ -491,6 +494,23 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq) } + test("inputFiles") { + val fakeRelation1 = new ParquetRelation(Array("/my/path", "/my/other/path"), + Some(testData.schema), None, Map.empty)(sqlContext) + val df1 = DataFrame(sqlContext, LogicalRelation(fakeRelation1)) + assert(df1.inputFiles.toSet == fakeRelation1.paths.toSet) + + val fakeRelation2 = new JSONRelation("/json/path", 1, Some(testData.schema), sqlContext) + val df2 = DataFrame(sqlContext, LogicalRelation(fakeRelation2)) + assert(df2.inputFiles.toSet == fakeRelation2.path.toSet) + + val unionDF = df1.unionAll(df2) + assert(unionDF.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + + val filtered = df1.filter("false").unionAll(df2.intersect(df2)) + assert(filtered.inputFiles.toSet == fakeRelation1.paths.toSet ++ fakeRelation2.path) + } + ignore("show") { // This test case is intended ignored, but to make sure it compiles correctly testData.select($"*").show() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 3180c05445c9f..a8c9b4fa71b99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -274,9 +274,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to + // NOTE: Instead of passing Metastore schema directly to `ParquetRelation`, we have to // serialize the Metastore schema to JSON and pass it as a data source option because of the - // evil case insensitivity issue, which is reconciled within `ParquetRelation2`. + // evil case insensitivity issue, which is reconciled within `ParquetRelation`. 
val parquetOptions = Map( ParquetRelation.METASTORE_SCHEMA -> metastoreSchema.json, ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString) @@ -290,7 +290,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { cachedDataSourceTables.getIfPresent(tableIdentifier) match { case null => None // Cache miss - case logical@LogicalRelation(parquetRelation: ParquetRelation) => + case logical @ LogicalRelation(parquetRelation: ParquetRelation) => // If we have the same paths, same schema, and same partition spec, // we will use the cached Parquet Relation. val useCached = From 614323406225a3522ee601935ce3052449614145 Mon Sep 17 00:00:00 2001 From: trestletech Date: Tue, 28 Jul 2015 10:45:19 -0700 Subject: [PATCH 115/219] Use vector-friendly comparison for packages argument. Otherwise, `sparkR.init()` with multiple `sparkPackages` results in this warning: ``` Warning message: In if (packages != "") { : the condition has length > 1 and only the first element will be used ``` Author: trestletech Closes #7701 from trestletech/compare-packages and squashes the following commits: 72c8b36 [trestletech] Correct function name. c52db0e [trestletech] Added test for multiple packages. 3aab1a7 [trestletech] Use vector-friendly comparison for packages argument. --- R/pkg/R/client.R | 2 +- R/pkg/inst/tests/test_client.R | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 6f772158ddfe8..c811d1dac3bd5 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -48,7 +48,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack jars <- paste("--jars", jars) } - if (packages != "") { + if (!identical(packages, "")) { packages <- paste("--packages", packages) } diff --git a/R/pkg/inst/tests/test_client.R b/R/pkg/inst/tests/test_client.R index 30b05c1a2afcd..8a20991f89af8 100644 --- a/R/pkg/inst/tests/test_client.R +++ b/R/pkg/inst/tests/test_client.R @@ -30,3 +30,7 @@ test_that("no package specified doesn't add packages flag", { expect_equal(gsub("[[:space:]]", "", args), "") }) + +test_that("multiple packages don't produce a warning", { + expect_that(generateSparkSubmitArgs("", "", "", "", c("A", "B")), not(gives_warning())) +}) From 31ec6a871eebd2377961c5195f9c2bff3a899fba Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 28 Jul 2015 11:48:56 -0700 Subject: [PATCH 116/219] [SPARK-9327] [DOCS] Fix documentation about classpath config options. Author: Marcelo Vanzin Closes #7651 from vanzin/SPARK-9327 and squashes the following commits: 2923e23 [Marcelo Vanzin] [SPARK-9327] [docs] Fix documentation about classpath config options. --- docs/configuration.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 200f3cd212e46..fd236137cb96e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -203,7 +203,7 @@ Apart from these, the following properties are also available, and may be useful spark.driver.extraClassPath (none) - Extra classpath entries to append to the classpath of the driver. + Extra classpath entries to prepend to the classpath of the driver.
Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. @@ -250,7 +250,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.extraClassPath (none) - Extra classpath entries to append to the classpath of executors. This exists primarily for + Extra classpath entries to prepend to the classpath of executors. This exists primarily for backwards-compatibility with older versions of Spark. Users typically should not need to set this option. From 6cdcc21fe654ac0a2d0d72783eb10005fc513af6 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 28 Jul 2015 13:16:48 -0700 Subject: [PATCH 117/219] [SPARK-9196] [SQL] Ignore test DatetimeExpressionsSuite: function current_timestamp. This test is flaky. https://issues.apache.org/jira/browse/SPARK-9196 will track the fix of it. For now, let's disable this test. Author: Yin Huai Closes #7727 from yhuai/SPARK-9196-ignore and squashes the following commits: f92bded [Yin Huai] Ignore current_timestamp. --- .../test/scala/org/apache/spark/sql/DateFunctionsSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 36820cbbc7e5e..07eb6e4a8d8cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -38,7 +38,8 @@ class DateFunctionsSuite extends QueryTest { assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1) } - test("function current_timestamp") { + // This is a bad test. SPARK-9196 will fix it and re-enable it. + ignore("function current_timestamp") { val df1 = Seq((1, 2), (3, 1)).toDF("a", "b") checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1)) // Execution in one query should return the same value From 8d5bb5283c3cc9180ef34b05be4a715d83073b1e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 28 Jul 2015 14:16:57 -0700 Subject: [PATCH 118/219] [SPARK-9391] [ML] Support minus, dot, and intercept operators in SparkR RFormula Adds '.', '-', and intercept parsing to RFormula. Also splits RFormulaParser into a separate file. Umbrella design doc here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit?usp=sharing mengxr Author: Eric Liang Closes #7707 from ericl/string-features-2 and squashes the following commits: 8588625 [Eric Liang] exclude complex types for . 
8106ffe [Eric Liang] comments a9350bb [Eric Liang] s/var/val 9c50d4d [Eric Liang] Merge branch 'string-features' into string-features-2 581afb2 [Eric Liang] Merge branch 'master' into string-features 08ae539 [Eric Liang] Merge branch 'string-features' into string-features-2 f99131a [Eric Liang] comments cecec43 [Eric Liang] Merge branch 'string-features' into string-features-2 0bf3c26 [Eric Liang] update docs 4592df2 [Eric Liang] intercept supports 7412a2e [Eric Liang] Fri Jul 24 14:56:51 PDT 2015 3cf848e [Eric Liang] fix the parser 0556c2b [Eric Liang] Merge branch 'string-features' into string-features-2 c302a2c [Eric Liang] fix tests 9d1ac82 [Eric Liang] Merge remote-tracking branch 'upstream/master' into string-features e713da3 [Eric Liang] comments cd231a9 [Eric Liang] Wed Jul 22 17:18:44 PDT 2015 4d79193 [Eric Liang] revert to seq + distinct 169a085 [Eric Liang] tweak functional test a230a47 [Eric Liang] Merge branch 'master' into string-features 72bd6f3 [Eric Liang] fix merge d841cec [Eric Liang] Merge branch 'master' into string-features 5b2c4a2 [Eric Liang] Mon Jul 20 18:45:33 PDT 2015 b01c7c5 [Eric Liang] add test 8a637db [Eric Liang] encoder wip a1d03f4 [Eric Liang] refactor into estimator --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 8 ++ .../apache/spark/ml/feature/RFormula.scala | 52 +++---- .../spark/ml/feature/RFormulaParser.scala | 129 ++++++++++++++++++ .../apache/spark/ml/r/SparkRWrappers.scala | 10 +- .../ml/feature/RFormulaParserSuite.scala | 55 +++++++- 6 files changed, 215 insertions(+), 41 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 258e354081fc1..6a8bacaa552c6 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~' and '+'. +#' operators are supported, including '~', '+', '-', and '.'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. #' @param lambda Regularization parameter diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 29152a11688a2..3bef69324770a 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -40,3 +40,11 @@ test_that("predictions match with native glm", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("dot minus and intercept vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0a95b1ee8de6e..0b428d278d908 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -78,13 +78,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** @group getParam */ def getFormula: String = $(formula) + /** Whether the formula specifies fitting an intercept. */ + private[ml] def hasIntercept: Boolean = { + require(parsedFormula.isDefined, "Must call setFormula() first.") + parsedFormula.get.hasIntercept + } + override def fit(dataset: DataFrame): RFormulaModel = { require(parsedFormula.isDefined, "Must call setFormula() first.") + val resolvedFormula = parsedFormula.get.resolve(dataset.schema) // StringType terms and terms representing interactions need to be encoded before assembly. // TODO(ekl) add support for feature interactions - var encoderStages = ArrayBuffer[PipelineStage]() - var tempColumns = ArrayBuffer[String]() - val encodedTerms = parsedFormula.get.terms.map { term => + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid @@ -103,7 +110,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R .setOutputCol($(featuresCol)) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) - copyValues(new RFormulaModel(uid, parsedFormula.get, pipelineModel).setParent(this)) + copyValues(new RFormulaModel(uid, resolvedFormula, pipelineModel).setParent(this)) } // optimistic schema; does not contain any ML attributes @@ -124,13 +131,13 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R /** * :: Experimental :: * A fitted RFormula. Fitting is required to determine the factor levels of formula terms. - * @param parsedFormula a pre-parsed R formula. + * @param resolvedFormula the fitted R formula. * @param pipelineModel the fitted feature model, including factor to index mappings. 
*/ @Experimental class RFormulaModel private[feature]( override val uid: String, - parsedFormula: ParsedRFormula, + resolvedFormula: ResolvedRFormula, pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase { @@ -144,8 +151,8 @@ class RFormulaModel private[feature]( val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else if (schema.exists(_.name == parsedFormula.label)) { - val nullable = schema(parsedFormula.label).dataType match { + } else if (schema.exists(_.name == resolvedFormula.label)) { + val nullable = schema(resolvedFormula.label).dataType match { case _: NumericType | BooleanType => false case _ => true } @@ -158,12 +165,12 @@ class RFormulaModel private[feature]( } override def copy(extra: ParamMap): RFormulaModel = copyValues( - new RFormulaModel(uid, parsedFormula, pipelineModel)) + new RFormulaModel(uid, resolvedFormula, pipelineModel)) - override def toString: String = s"RFormulaModel(${parsedFormula})" + override def toString: String = s"RFormulaModel(${resolvedFormula})" private def transformLabel(dataset: DataFrame): DataFrame = { - val labelName = parsedFormula.label + val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { dataset } else if (dataset.schema.exists(_.name == labelName)) { @@ -207,26 +214,3 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) } - -/** - * Represents a parsed R formula. - */ -private[ml] case class ParsedRFormula(label: String, terms: Seq[String]) - -/** - * Limited implementation of R formula parsing. Currently supports: '~', '+'. - */ -private[ml] object RFormulaParser extends RegexParsers { - def term: Parser[String] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r - - def expr: Parser[List[String]] = term ~ rep("+" ~> term) ^^ { case a ~ list => a :: list } - - def formula: Parser[ParsedRFormula] = - (term ~ "~" ~ expr) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t.distinct) } - - def parse(value: String): ParsedRFormula = parseAll(formula, value) match { - case Success(result, _) => result - case failure: NoSuccess => throw new IllegalArgumentException( - "Could not parse formula: " + value) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala new file mode 100644 index 0000000000000..1ca3b92a7d92a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.types._ + +/** + * Represents a parsed R formula. + */ +private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { + /** + * Resolves formula terms into column names. A schema is necessary for inferring the meaning + * of the special '.' term. Duplicate terms will be removed during resolution. + */ + def resolve(schema: StructType): ResolvedRFormula = { + var includedTerms = Seq[String]() + terms.foreach { + case Dot => + includedTerms ++= simpleTypes(schema).filter(_ != label.value) + case ColumnRef(value) => + includedTerms :+= value + case Deletion(term: Term) => + term match { + case ColumnRef(value) => + includedTerms = includedTerms.filter(_ != value) + case Dot => + // e.g. "- .", which removes all first-order terms + val fromSchema = simpleTypes(schema) + includedTerms = includedTerms.filter(fromSchema.contains(_)) + case _: Deletion => + assert(false, "Deletion terms cannot be nested") + case _: Intercept => + } + case _: Intercept => + } + ResolvedRFormula(label.value, includedTerms.distinct) + } + + /** Whether this formula specifies fitting with an intercept term. */ + def hasIntercept: Boolean = { + var intercept = true + terms.foreach { + case Intercept(enabled) => + intercept = enabled + case Deletion(Intercept(enabled)) => + intercept = !enabled + case _ => + } + intercept + } + + // the dot operator excludes complex column types + private def simpleTypes(schema: StructType): Seq[String] = { + schema.fields.filter(_.dataType match { + case _: NumericType | StringType | BooleanType | _: VectorUDT => true + case _ => false + }).map(_.name) + } +} + +/** + * Represents a fully evaluated and simplified R formula. + */ +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) + +/** + * R formula terms. See the R formula docs here for more information: + * http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + */ +private[ml] sealed trait Term + +/* R formula reference to all available columns, e.g. "." in a formula */ +private[ml] case object Dot extends Term + +/* R formula reference to a column, e.g. "+ Species" in a formula */ +private[ml] case class ColumnRef(value: String) extends Term + +/* R formula intercept toggle, e.g. "+ 0" in a formula */ +private[ml] case class Intercept(enabled: Boolean) extends Term + +/* R formula deletion of a variable, e.g. "- Species" in a formula */ +private[ml] case class Deletion(term: Term) extends Term + +/** + * Limited implementation of R formula parsing. Currently supports: '~', '+', '-', '.'. 
+ */ +private[ml] object RFormulaParser extends RegexParsers { + def intercept: Parser[Intercept] = + "([01])".r ^^ { case a => Intercept(a == "1") } + + def columnRef: Parser[ColumnRef] = + "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } + + def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + + def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { + case op ~ list => list.foldLeft(List(op)) { + case (left, "+" ~ right) => left ++ Seq(right) + case (left, "-" ~ right) => left ++ Seq(Deletion(right)) + } + } + + def formula: Parser[ParsedRFormula] = + (columnRef ~ "~" ~ terms) ^^ { case r ~ "~" ~ t => ParsedRFormula(r, t) } + + def parse(value: String): ParsedRFormula = parseAll(formula, value) match { + case Success(result, _) => result + case failure: NoSuccess => throw new IllegalArgumentException( + "Could not parse formula: " + value) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 1ee080641e3e3..9f70592ccad7e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -32,8 +32,14 @@ private[r] object SparkRWrappers { alpha: Double): PipelineModel = { val formula = new RFormula().setFormula(value) val estimator = family match { - case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) - case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "gaussian" => new LinearRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) + case "binomial" => new LogisticRegression() + .setRegParam(lambda) + .setElasticNetParam(alpha) + .setFitIntercept(formula.hasIntercept) } val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index c4b45aee06384..436e66bab09b0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -18,12 +18,17 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ class RFormulaParserSuite extends SparkFunSuite { - private def checkParse(formula: String, label: String, terms: Seq[String]) { - val parsed = RFormulaParser.parse(formula) - assert(parsed.label == label) - assert(parsed.terms == terms) + private def checkParse( + formula: String, + label: String, + terms: Seq[String], + schema: StructType = null) { + val resolved = RFormulaParser.parse(formula).resolve(schema) + assert(resolved.label == label) + assert(resolved.terms == terms) } test("parse simple formulas") { @@ -32,4 +37,46 @@ class RFormulaParserSuite extends SparkFunSuite { checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } + + test("parse dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ .", "a", Seq("b", "c"), schema) + } + + test("parse deletion") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ c - b", "a", Seq("c"), schema) + 
} + + test("parse additions and deletions in order") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("a ~ . - b + . - c", "a", Seq("b"), schema) + } + + test("dot ignores complex column types") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "tinyint", false) + .add("c", "map<string,string>", true) + checkParse("a ~ .", "a", Seq("b"), schema) + } + + test("parse intercept") { + assert(RFormulaParser.parse("a ~ b").hasIntercept) + assert(RFormulaParser.parse("a ~ b + 1").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 0").hasIntercept) + assert(RFormulaParser.parse("a ~ b - 1 + 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 0").hasIntercept) + assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) + assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) + } } From b88b868eb378bdb7459978842b5572a0b498f412 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Tue, 28 Jul 2015 14:39:25 -0700 Subject: [PATCH 119/219] [SPARK-8003][SQL] Added virtual column support to Spark Added virtual column support by adding a new resolution rule to the query analyzer. Additional virtual columns can be added by adding case expressions to [the new rule](https://github.com/JDrit/spark/blob/virt_columns/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala#L1026) and by modifying the [logical plan](https://github.com/JDrit/spark/blob/virt_columns/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala#L216) to resolve them. This also solves [SPARK-8003](https://issues.apache.org/jira/browse/SPARK-8003). This allows you to perform queries such as: ```sql select spark__partition__id, count(*) as c from table group by spark__partition__id; ``` Author: Joseph Batchik Author: JD Closes #7478 from JDrit/virt_columns and squashes the following commits: 7932bf0 [Joseph Batchik] adding spark__partition__id to hive as well f8a9c6c [Joseph Batchik] merging in master e49da48 [JD] fixes for @rxin's suggestions 60e120b [JD] fixing test in merge 4bf8554 [JD] merging in master c68bc0f [Joseph Batchik] Adding function register ability to SQLContext and adding a function for spark__partition__id() --- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 +- .../scala/org/apache/spark/sql/SQLContext.scala | 11 ++++++++++- .../execution/expressions/SparkPartitionID.scala | 2 +- .../main/scala/org/apache/spark/sql/functions.scala | 2 +- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 7 +++++++ .../expression/NondeterministicSuite.scala | 2 +- .../org/apache/spark/sql/hive/HiveContext.scala | 13 +++++++++++-- .../scala/org/apache/spark/sql/hive/UDFSuite.scala | 9 ++++++++- 8 files changed, 40 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 61ee6f6f71631..9b60943a1e147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -239,7 +239,7 @@ object FunctionRegistry { } /** See usage above. 
*/ - private def expression[T <: Expression](name: String) + def expression[T <: Expression](name: String) (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = { // See if we can find a constructor that accepts Seq[Expression] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dbb2a09846548..56cd8f22e7cf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,6 +31,8 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder} +import org.apache.spark.sql.execution.expressions.SparkPartitionID import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.errors.DialectException @@ -140,7 +142,14 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? @transient - protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin + protected[sql] lazy val functionRegistry: FunctionRegistry = { + val reg = FunctionRegistry.builtin + val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))]( + FunctionExpression[SparkPartitionID]("spark__partition__id") + ) + extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) } + reg + } @transient protected[sql] lazy val analyzer: Analyzer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala index 61ef079d89af5..98c8eab8372aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{IntegerType, DataType} /** * Expression that returns the current partition id of the Spark task. */ -private[sql] case object SparkPartitionID extends LeafExpression with Nondeterministic { +private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index cec61b66b157c..0148991512213 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -741,7 +741,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = execution.expressions.SparkPartitionID + def sparkPartitionId(): Column = execution.expressions.SparkPartitionID() /** * Computes the square root of the specified float value. 
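For reference, a minimal sketch of the DataFrame-side helper changed above (the data is made up and `sqlContext` is assumed to be in scope):

```scala
// Hypothetical sketch: tag each row with the partition that produced it.
import org.apache.spark.sql.functions.sparkPartitionId
val ids = sqlContext.range(0L, 100L, 1L, 4)       // 4 partitions, values 0..99
ids.select(sparkPartitionId()).distinct().show()  // expect partition ids 0 through 3
```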
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index c1516b450cbd4..9b326c16350c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -51,6 +51,13 @@ class UDFSuite extends QueryTest { df.selectExpr("count(distinct a)") } + test("SPARK-8003 spark__partition__id") { + val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") + df.registerTempTable("tmp_table") + checkAnswer(ctx.sql("select spark__partition__id() from tmp_table").toDF(), Row(0)) + ctx.dropTempTable("tmp_table") + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala index 1c5a2ed2c0a53..b6e79ff9cc95d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala @@ -27,6 +27,6 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SparkPartitionID") { - checkEvaluation(SparkPartitionID, 0) + checkEvaluation(SparkPartitionID(), 0) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 110f51a305861..8b35c1275f388 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,6 +38,9 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder} +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo +import org.apache.spark.sql.execution.expressions.SparkPartitionID import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} @@ -372,8 +375,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin) + override protected[sql] lazy val functionRegistry: FunctionRegistry = { + val reg = new HiveFunctionRegistry(FunctionRegistry.builtin) + val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))]( + FunctionExpression[SparkPartitionID]("spark__partition__id") + ) + extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) } + reg + } /* An analyzer that uses the Hive metastore. 
*/ @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 4056dee777574..9cea5d413c817 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{Row, QueryTest} case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.hive.test.TestHive + import ctx.implicits._ test("UDF case insensitive") { ctx.udf.register("random0", () => { Math.random() }) @@ -33,4 +34,10 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } + + test("SPARK-8003 spark__partition__id") { + val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying") + ctx.registerDataFrameAsTable(df, "test_table") + checkAnswer(ctx.sql("select spark__partition__id() from test_table LIMIT 1").toDF(), Row(0)) + } } From 198d181dfb2c04102afe40680a4637d951e92c0b Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 28 Jul 2015 15:00:25 -0700 Subject: [PATCH 120/219] [SPARK-7105] [PYSPARK] [MLLIB] Support model save/load in GMM This PR introduces save/load for GMMs in the Python API. I also refactored `GaussianMixtureModel` to inherit from `JavaModelWrapper`, with the model being a `GaussianMixtureModelWrapper`, a wrapper that provides convenience methods for `GaussianMixtureModel` (needed to work around serialization and deserialization issues), and moved the creation of the gaussians to the Scala backend.
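Since the Python wrapper delegates to the Scala MLlib model, the round trip it enables can be sketched on the Scala side as follows (a minimal sketch; the data and path are made up):

```scala
// Hedged sketch of the Scala save/load that backs the new Python API.
import org.apache.spark.mllib.clustering.{GaussianMixture, GaussianMixtureModel}
import org.apache.spark.mllib.linalg.Vectors

val points = sc.parallelize(Seq(
  Vectors.dense(-0.1, -0.05), Vectors.dense(-0.01, -0.1),
  Vectors.dense(0.9, 0.8), Vectors.dense(0.75, 0.935)))
val model = new GaussianMixture().setK(2).run(points)
model.save(sc, "/tmp/gmm-model")  // persists weights and gaussians
val sameModel = GaussianMixtureModel.load(sc, "/tmp/gmm-model")
assert(sameModel.weights.sameElements(model.weights))
```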
+ */ + +package org.apache.spark.mllib.api.python + +import java.util.{List => JList} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrix} +import org.apache.spark.mllib.clustering.GaussianMixtureModel + +/** + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ +private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { + val weights: Vector = Vectors.dense(model.weights) + val k: Int = weights.size + + /** + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ + val gaussians: JList[Object] = { + val modelGaussians = model.gaussians + var i = 0 + var mu = ArrayBuffer.empty[Vector] + var sigma = ArrayBuffer.empty[Matrix] + while (i < k) { + mu += modelGaussians(i).mu + sigma += modelGaussians(i).sigma + i += 1 + } + List(mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index fda8d5a0b048f..6f080d32bbf4d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -364,7 +364,7 @@ private[python] class PythonMLLibAPI extends Serializable { seed: java.lang.Long, initialModelWeights: java.util.ArrayList[Double], initialModelMu: java.util.ArrayList[Vector], - initialModelSigma: java.util.ArrayList[Matrix]): JList[Object] = { + initialModelSigma: java.util.ArrayList[Matrix]): GaussianMixtureModelWrapper = { val gmmAlg = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) @@ -382,16 +382,7 @@ private[python] class PythonMLLibAPI extends Serializable { if (seed != null) gmmAlg.setSeed(seed) try { - val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) - var wt = ArrayBuffer.empty[Double] - var mu = ArrayBuffer.empty[Vector] - var sigma = ArrayBuffer.empty[Matrix] - for (i <- 0 until model.k) { - wt += model.weights(i) - mu += model.gaussians(i).mu - sigma += model.gaussians(i).sigma - } - List(Vectors.dense(wt.toArray), mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava + new GaussianMixtureModelWrapper(gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))) } finally { data.rdd.unpersist(blocking = false) } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 58ad99d46e23b..900ade248c386 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -152,11 +152,19 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" return KMeansModel([c.toArray() for c in centers]) -class GaussianMixtureModel(object): +@inherit_doc +class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental - """A clustering model derived from the Gaussian Mixture Model method. + A clustering model derived from the Gaussian Mixture Model method. >>> from pyspark.mllib.linalg import Vectors, DenseMatrix + >>> from numpy.testing import assert_equal + >>> from shutil import rmtree + >>> import os, tempfile + >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1, ... 0.9,0.8,0.75,0.935, ... 
-0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2)) @@ -169,6 +177,25 @@ class GaussianMixtureModel(object): True >>> labels[4]==labels[5] True + + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = GaussianMixtureModel.load(sc, path) + >>> assert_equal(model.weights, sameModel.weights) + >>> mus, sigmas = list( + ... zip(*[(g.mu, g.sigma) for g in model.gaussians])) + >>> sameMus, sameSigmas = list( + ... zip(*[(g.mu, g.sigma) for g in sameModel.gaussians])) + >>> mus == sameMus + True + >>> sigmas == sameSigmas + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + >>> data = array([-5.1971, -2.5359, -3.8220, ... -5.2211, -5.0602, 4.7118, ... 6.8989, 3.4592, 4.6322, @@ -182,25 +209,15 @@ class GaussianMixtureModel(object): True >>> labels[3]==labels[4] True - >>> clusterdata_3 = sc.parallelize(data.reshape(15, 1)) - >>> im = GaussianMixtureModel([0.5, 0.5], - ... [MultivariateGaussian(Vectors.dense([-1.0]), DenseMatrix(1, 1, [1.0])), - ... MultivariateGaussian(Vectors.dense([1.0]), DenseMatrix(1, 1, [1.0]))]) - >>> model = GaussianMixture.train(clusterdata_3, 2, initialModel=im) """ - def __init__(self, weights, gaussians): - self._weights = weights - self._gaussians = gaussians - self._k = len(self._weights) - @property def weights(self): """ Weights for each Gaussian distribution in the mixture, where weights[i] is the weight for Gaussian i, and weights.sum == 1. """ - return self._weights + return array(self.call("weights")) @property def gaussians(self): @@ -208,12 +225,14 @@ def gaussians(self): Array of MultivariateGaussian where gaussians[i] represents the Multivariate Gaussian (Normal) Distribution for Gaussian i. """ - return self._gaussians + return [ + MultivariateGaussian(gaussian[0], gaussian[1]) + for gaussian in zip(*self.call("gaussians"))] @property def k(self): """Number of gaussians in mixture.""" - return self._k + return len(self.weights) def predict(self, x): """ @@ -238,17 +257,30 @@ def predictSoft(self, x): :return: membership_matrix. RDD of array of double values. """ if isinstance(x, RDD): - means, sigmas = zip(*[(g.mu, g.sigma) for g in self._gaussians]) + means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians]) membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector), - _convert_to_vector(self._weights), means, sigmas) + _convert_to_vector(self.weights), means, sigmas) return membership_matrix.map(lambda x: pyarray.array('d', x)) else: raise TypeError("x should be represented by an RDD, " "but got %s." % type(x)) + @classmethod + def load(cls, sc, path): + """Load the GaussianMixtureModel from disk. + + :param sc: SparkContext + :param path: str, path to where the model is stored. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.GaussianMixtureModelWrapper(model) + return cls(wrapper) + class GaussianMixture(object): """ + .. note:: Experimental + Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm. 
    :param data: RDD of data points
@@ -271,11 +303,10 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia
            initialModelWeights = initialModel.weights
            initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)]
            initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)]
-        weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
-                                          k, convergenceTol, maxIterations, seed,
-                                          initialModelWeights, initialModelMu, initialModelSigma)
-        mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
-        return GaussianMixtureModel(weight, mvg_obj)
+        java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector),
+                                   k, convergenceTol, maxIterations, seed,
+                                   initialModelWeights, initialModelMu, initialModelSigma)
+        return GaussianMixtureModel(java_model)


 class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 875d3b2d642c6..916de2d6fcdbd 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -21,7 +21,9 @@

 if sys.version > '3':
     xrange = range
+    basestring = str

+from pyspark import SparkContext
 from pyspark.mllib.common import callMLlibFunc, inherit_doc
 from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector

@@ -223,6 +225,10 @@ class JavaSaveable(Saveable):
     """

     def save(self, sc, path):
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a basestring, got type %s" % type(path))
         self._java_model.save(sc._jsc.sc(), path)

From 21825529eae66293ec5d8638911303fa54944dd5 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Tue, 28 Jul 2015 15:56:19 -0700
Subject: [PATCH 121/219] [SPARK-9247] [SQL] Use BytesToBytesMap for broadcast join

This PR introduces BytesToBytesMap to UnsafeHashedRelation and uses it in the
executor for better performance. It serializes all the keys and values from the
java HashMap and puts them into a BytesToBytesMap while deserializing. All the
values for the same key are stored contiguously for better memory locality.

This PR also addresses the comments for #7480 and does some cleanup.
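To make the serialized layout concrete, here is a minimal standalone sketch of decoding one key's packed value region. This is illustrative only: ValueLayoutSketch and decode are invented names, and the real decoding lives in UnsafeHashedRelation.get in the diff below.

    import java.nio.{ByteBuffer, ByteOrder}
    import scala.collection.mutable.ArrayBuffer

    // Sketch: each value is encoded as [number of fields][number of bytes][row bytes],
    // and all values for one key are packed back to back; the two ints are written
    // in native byte order.
    object ValueLayoutSketch {
      def decode(region: Array[Byte]): Seq[(Int, Array[Byte])] = {
        val buf = ByteBuffer.wrap(region).order(ByteOrder.nativeOrder())
        val values = ArrayBuffer.empty[(Int, Array[Byte])]
        while (buf.remaining() > 0) {
          val numFields = buf.getInt()
          val sizeInBytes = buf.getInt()
          val rowBytes = new Array[Byte](sizeInBytes)
          buf.get(rowBytes) // the raw bytes of one UnsafeRow
          values += ((numFields, rowBytes))
        }
        values.toSeq
      }
    }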
Author: Davies Liu Closes #7592 from davies/unsafe_map2 and squashes the following commits: 42c578a [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_map2 fd09528 [Davies Liu] remove thread local cache and update docs 1c5ad8d [Davies Liu] fix test 5eb1b5a [Davies Liu] address comments in #7480 46f1f22 [Davies Liu] fix style fc221e0 [Davies Liu] use BytesToBytesMap for broadcast join --- .../execution/joins/BroadcastHashJoin.scala | 2 +- .../joins/BroadcastHashOuterJoin.scala | 2 +- .../joins/BroadcastLeftSemiJoinHash.scala | 6 +- .../joins/BroadcastNestedLoopJoin.scala | 36 ++-- .../spark/sql/execution/joins/HashJoin.scala | 35 ++-- .../sql/execution/joins/HashOuterJoin.scala | 34 ++-- .../sql/execution/joins/HashSemiJoin.scala | 14 +- .../sql/execution/joins/HashedRelation.scala | 166 ++++++++++++++---- .../execution/joins/LeftSemiJoinHash.scala | 2 +- .../execution/joins/ShuffledHashJoin.scala | 2 +- .../joins/ShuffledHashOuterJoin.scala | 8 +- .../execution/joins/HashedRelationSuite.scala | 28 +-- 12 files changed, 214 insertions(+), 121 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index abaa4a6ce86a2..624efc1b1d734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -62,7 +62,7 @@ case class BroadcastHashJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashJoin.broadcastHashJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index c9d1a880f4ef4..77e7fe71009b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin( private val broadcastFuture = future { // Note that we use .execute().collect() because we don't want to convert data to Scala types val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect() - val hashed = buildHashRelation(input.iterator) + val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size) sparkContext.broadcast(hashed) }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala index f71c0ce352904..a60593911f94f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala @@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash( condition: Option[Expression]) extends BinaryNode with HashSemiJoin { protected override def doExecute(): RDD[InternalRow] = { - val buildIter = 
right.execute().map(_.copy()).collect().toIterator + val input = right.execute().map(_.copy()).collect() if (condition.isEmpty) { - val hashSet = buildKeyHashSet(buildIter) + val hashSet = buildKeyHashSet(input.toIterator) val broadcastedRelation = sparkContext.broadcast(hashSet) left.execute().mapPartitions { streamIter => hashSemiJoin(streamIter, broadcastedRelation.value) } } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size) val broadcastedRelation = sparkContext.broadcast(hashRelation) left.execute().mapPartitions { streamIter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 700636966f8be..83b726a8e2897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin( override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows override def canProcessUnsafeRows: Boolean = true - @transient private[this] lazy val resultProjection: Projection = { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { if (outputsUnsafeRows) { UnsafeProjection.create(schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin( var streamRowMatched = false while (i < broadcastedRelation.value.size) { - // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) buildSide match { case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => @@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin( val buf: CompactBuffer[InternalRow] = new CompactBuffer() var i = 0 val rel = broadcastedRelation.value - while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { - (joinType, buildSide) match { - case (RightOuter | FullOuter, BuildRight) => - buf += resultProjection(new JoinedRow(leftNulls, rel(i))) - case (LeftOuter | FullOuter, BuildLeft) => - buf += resultProjection(new JoinedRow(rel(i), rightNulls)) - case _ => + (joinType, buildSide) match { + case (RightOuter | FullOuter, BuildRight) => + val joinedRow = new JoinedRow + joinedRow.withLeft(leftNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withRight(rel(i))).copy() + } + i += 1 } - } - i += 1 + case (LeftOuter | FullOuter, BuildLeft) => + val joinedRow = new JoinedRow + joinedRow.withRight(rightNulls) + while (i < rel.length) { + if (!allIncludedBroadcastTuples.contains(i)) { + buf += resultProjection(joinedRow.withLeft(rel(i))).copy() + } + i += 1 + } + case _ => } buf.toSeq } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 46ab5b0d1cc6d..6b3d1652923fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.util.collection.CompactBuffer trait HashJoin { @@ -44,16 +43,24 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode + + @transient protected lazy val buildSideKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } @transient protected lazy val streamSideKeyGenerator: Projection = - if (supportUnsafe) { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newMutableProjection(streamedKeys, streamedPlan.output)() @@ -65,18 +72,16 @@ trait HashJoin { { new Iterator[InternalRow] { private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: CompactBuffer[InternalRow] = _ + private[this] var currentHashMatches: Seq[InternalRow] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. private[this] val joinRow = new JoinedRow - private[this] val resultProjection: Projection = { - if (supportUnsafe) { + private[this] val resultProjection: (InternalRow) => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -122,12 +127,4 @@ trait HashJoin { } } } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 6bf2f82954046..7e671e7914f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -75,30 +75,36 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - protected[this] def supportUnsafe: Boolean = { + protected[this] def isUnsafeMode: Boolean = { (self.codegenEnabled && joinType != FullOuter && UnsafeProjection.canSupport(buildKeys) && UnsafeProjection.canSupport(self.schema)) } - override def outputsUnsafeRows: Boolean = supportUnsafe - override def canProcessUnsafeRows: Boolean = supportUnsafe + override def outputsUnsafeRows: Boolean = isUnsafeMode + override def canProcessUnsafeRows: Boolean = isUnsafeMode + override def canProcessSafeRows: Boolean = !isUnsafeMode - protected[this] def streamedKeyGenerator(): Projection = { - if (supportUnsafe) { + @transient protected lazy val buildKeyGenerator: Projection = + if (isUnsafeMode) { + UnsafeProjection.create(buildKeys, buildPlan.output) + } else { + newMutableProjection(buildKeys, buildPlan.output)() + } + + @transient protected[this] lazy val streamedKeyGenerator: 
Projection = { + if (isUnsafeMode) { UnsafeProjection.create(streamedKeys, streamedPlan.output) } else { newProjection(streamedKeys, streamedPlan.output) } } - @transient private[this] lazy val resultProjection: Projection = { - if (supportUnsafe) { + @transient private[this] lazy val resultProjection: InternalRow => InternalRow = { + if (isUnsafeMode) { UnsafeProjection.create(self.schema) } else { - new Projection { - override def apply(r: InternalRow): InternalRow = r - } + identity[InternalRow] } } @@ -230,12 +236,4 @@ trait HashOuterJoin { hashTable } - - protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, buildKeys, buildPlan) - } else { - HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output)) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index 7f49264d40354..97fde8f975bfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -35,11 +35,13 @@ trait HashSemiJoin { protected[this] def supportUnsafe: Boolean = { (self.codegenEnabled && UnsafeProjection.canSupport(leftKeys) && UnsafeProjection.canSupport(rightKeys) - && UnsafeProjection.canSupport(left.schema)) + && UnsafeProjection.canSupport(left.schema) + && UnsafeProjection.canSupport(right.schema)) } - override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows + override def outputsUnsafeRows: Boolean = supportUnsafe override def canProcessUnsafeRows: Boolean = supportUnsafe + override def canProcessSafeRows: Boolean = !supportUnsafe @transient protected lazy val leftKeyGenerator: Projection = if (supportUnsafe) { @@ -87,14 +89,6 @@ trait HashSemiJoin { }) } - protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = { - if (supportUnsafe) { - UnsafeHashedRelation(buildIter, rightKeys, right) - } else { - HashedRelation(buildIter, newProjection(rightKeys, right.output)) - } - } - protected def hashSemiJoin( streamIter: Iterator[InternalRow], hashedRelation: HashedRelation): Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8d5731afd59b8..9c058f1f72fe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.util.collection.CompactBuffer @@ -32,7 +35,7 @@ import org.apache.spark.util.collection.CompactBuffer * object. 
 */
private[joins] sealed trait HashedRelation {
-  def get(key: InternalRow): CompactBuffer[InternalRow]
+  def get(key: InternalRow): Seq[InternalRow]

   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
@@ -59,9 +62,9 @@ private[joins] final class GeneralHashedRelation(
     private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
   extends HashedRelation with Externalizable {

-  def this() = this(null) // Needed for serialization
+  private def this() = this(null) // Needed for serialization

-  override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key)
+  override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)

   override def writeExternal(out: ObjectOutput): Unit = {
     writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@@ -81,9 +84,9 @@ private[joins] final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
   extends HashedRelation with Externalizable {

-  def this() = this(null) // Needed for serialization
+  private def this() = this(null) // Needed for serialization

-  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+  override def get(key: InternalRow): Seq[InternalRow] = {
     val v = hashTable.get(key)
     if (v eq null) null else CompactBuffer(v)
   }
@@ -109,6 +112,10 @@ private[joins] object HashedRelation {
       keyGenerator: Projection,
       sizeEstimate: Int = 64): HashedRelation = {

+    if (keyGenerator.isInstanceOf[UnsafeProjection]) {
+      return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
+    }
+
     // TODO: Use Spark's HashMap implementation.
     val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
     var currentRow: InternalRow = null
@@ -149,31 +156,133 @@ private[joins] object HashedRelation {
   }
 }

-/**
- * A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
- * sequence of values.
+ * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
+ * into a sequence of values.
+ *
+ * When it's created, it uses HashMap. After it's serialized and deserialized, it switches to
+ * BytesToBytesMap for better memory performance (multiple values for the same key are stored as a
+ * contiguous byte array).
- *
- * TODO(davies): use BytesToBytesMap
+ * It's serialized in the following format:
+ * [number of keys]
+ * [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
+ * ...
+ *
+ * All the values are serialized as follows:
+ * [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
+ * ...
 */
private[joins] final class UnsafeHashedRelation(
    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
  extends HashedRelation with Externalizable {

-  def this() = this(null) // Needed for serialization
+  private[joins] def this() = this(null) // Needed for serialization
+
+  // Use BytesToBytesMap in executor for better performance (it's created during deserialization)
+  @transient private[this] var binaryMap: BytesToBytesMap = _

-  override def get(key: InternalRow): CompactBuffer[InternalRow] = {
+  override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
-    // Thanks to type eraser
-    hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
+
+    if (binaryMap != null) {
+      // Used in Broadcast join
+      val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+        unsafeKey.getSizeInBytes)
+      if (loc.isDefined) {
+        val buffer = CompactBuffer[UnsafeRow]()
+
+        val base = loc.getValueAddress.getBaseObject
+        var offset = loc.getValueAddress.getBaseOffset
+        val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
+        while (offset < last) {
+          val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
+          val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
+          offset += 8
+
+          val row = new UnsafeRow
+          row.pointTo(base, offset, numFields, sizeInBytes)
+          buffer += row
+          offset += sizeInBytes
+        }
+        buffer
+      } else {
+        null
+      }
+
+    } else {
+      // Use the JavaHashMap in Local mode or ShuffleHashJoin
+      hashTable.get(unsafeKey)
+    }
   }

   override def writeExternal(out: ObjectOutput): Unit = {
-    writeBytes(out, SparkSqlSerializer.serialize(hashTable))
+    out.writeInt(hashTable.size())
+
+    val iter = hashTable.entrySet().iterator()
+    while (iter.hasNext) {
+      val entry = iter.next()
+      val key = entry.getKey
+      val values = entry.getValue
+
+      // write all the values as a single byte array
+      var totalSize = 0L
+      var i = 0
+      while (i < values.size) {
+        totalSize += values(i).getSizeInBytes + 4 + 4
+        i += 1
+      }
+      assert(totalSize < Integer.MAX_VALUE, "values are too big")
+
+      // [key size] [values size] [key bytes] [values bytes]
+      out.writeInt(key.getSizeInBytes)
+      out.writeInt(totalSize.toInt)
+      out.write(key.getBytes)
+      i = 0
+      while (i < values.size) {
+        // [num of fields] [num of bytes] [row bytes]
+        // write the integers in native order, so they can be read by UNSAFE.getInt()
+        if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
+          out.writeInt(values(i).numFields())
+          out.writeInt(values(i).getSizeInBytes)
+        } else {
+          out.writeInt(Integer.reverseBytes(values(i).numFields()))
+          out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
+        }
+        out.write(values(i).getBytes)
+        i += 1
+      }
+    }
   }

   override def readExternal(in: ObjectInput): Unit = {
-    hashTable = SparkSqlSerializer.deserialize(readBytes(in))
+    val nKeys = in.readInt()
+    // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
+    val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash collision
+
+    var i = 0
+    var keyBuffer = new Array[Byte](1024)
+    var valuesBuffer = new Array[Byte](1024)
+    while (i < nKeys) {
+      val keySize = in.readInt()
+      val valuesSize = in.readInt()
+      if (keySize > keyBuffer.size) {
+        keyBuffer = new Array[Byte](keySize)
+      }
+      in.readFully(keyBuffer, 0, keySize)
+      if (valuesSize > valuesBuffer.size) {
+        valuesBuffer = new Array[Byte](valuesSize)
+      }
+      in.readFully(valuesBuffer, 0, valuesSize)
+
+
// put it into binary map + val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize) + assert(!loc.isDefined, "Duplicated key found!") + loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize, + valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize) + i += 1 + } } } @@ -181,33 +290,14 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - buildKeys: Seq[Expression], - buildPlan: SparkPlan, - sizeEstimate: Int = 64): HashedRelation = { - val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output)) - apply(input, boundedKeys, buildPlan.schema, sizeEstimate) - } - - // Used for tests - def apply( - input: Iterator[InternalRow], - buildKeys: Seq[Expression], - rowSchema: StructType, + keyGenerator: UnsafeProjection, sizeEstimate: Int): HashedRelation = { - // TODO: Use BytesToBytesMap. val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate) - val toUnsafe = UnsafeProjection.create(rowSchema) - val keyGenerator = UnsafeProjection.create(buildKeys) // Create a mapping of buildKeys -> rows while (input.hasNext) { - val currentRow = input.next() - val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) { - currentRow.asInstanceOf[UnsafeRow] - } else { - toUnsafe(currentRow) - } + val unsafeRow = input.next().asInstanceOf[UnsafeRow] val rowKey = keyGenerator(unsafeRow) if (!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 874712a4e739f..26a664104d6fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -46,7 +46,7 @@ case class LeftSemiJoinHash( val hashSet = buildKeyHashSet(buildIter) hashSemiJoin(streamIter, hashSet) } else { - val hashRelation = buildHashRelation(buildIter) + val hashRelation = HashedRelation(buildIter, rightKeyGenerator) hashSemiJoin(streamIter, hashRelation) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 948d0ccebceb0..5439e10a60b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,7 +45,7 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => - val hashed = buildHashRelation(buildIter) + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) hashJoin(streamIter, hashed) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index f54f1edd38ec8..d29b593207c4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -50,8 +50,8 @@ case class ShuffledHashOuterJoin( // TODO this probably can be replaced by external sort (sort merged join?) 
joinType match { case LeftOuter => - val hashed = buildHashRelation(rightIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(rightIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator leftIter.flatMap( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withLeft(currentRow) @@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin( }) case RightOuter => - val hashed = buildHashRelation(leftIter) - val keyGenerator = streamedKeyGenerator() + val hashed = HashedRelation(leftIter, buildKeyGenerator) + val keyGenerator = streamedKeyGenerator rightIter.flatMap ( currentRow => { val rowKey = keyGenerator(currentRow) joinedRow.withRight(currentRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 9dd2220f0967e..8b1a9b21a96b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.execution.joins +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.SparkSqlSerializer -import org.apache.spark.sql.types.{StructField, StructType, IntegerType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.collection.CompactBuffer @@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite { } test("UnsafeHashedRelation") { + val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) + val toUnsafe = UnsafeProjection.create(schema) + val unsafeData = data.map(toUnsafe(_).copy()).toArray + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val schema = StructType(StructField("a", IntegerType, true) :: Nil) - val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - val toUnsafeKey = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafeKey(_).copy()).toArray assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - assert(hashed.get(toUnsafeKey(InternalRow(10))) === null) + assert(hashed.get(toUnsafe(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() assert(hashed.get(unsafeData(2)) === data2) - val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed)) - .asInstanceOf[UnsafeHashedRelation] + val os = new ByteArrayOutputStream() + val out = new ObjectOutputStream(os) + hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out) + out.flush() + val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) + val hashed2 = new UnsafeHashedRelation() + hashed2.readExternal(in) assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) - 
assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
+    assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
     assert(hashed2.get(unsafeData(2)) === data2)
   }
 }

From 59b92add7cc9cca1eaf0c558edb7c4add66c284f Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 28 Jul 2015 16:04:48 -0700
Subject: [PATCH 122/219] [SPARK-9393] [SQL] Fix several error-handling bugs in ScriptTransform operator

SparkSQL's ScriptTransform operator has several serious bugs which make debugging fairly difficult:

- If exceptions are thrown in the writing thread then the child process will not be killed, leading to a deadlock because the reader thread will block while waiting for input that will never arrive.
- TaskContext is not propagated to the writer thread, which may cause errors in upstream pipelined operators.
- Exceptions which occur in the writer thread are not propagated to the main reader thread, which may cause upstream errors to be silently ignored instead of killing the job. This can lead to silently incorrect query results.
- The writer thread is not a daemon thread, but it should be.

In addition, the code in this file is extremely messy:

- Lots of fields are nullable but the nullability isn't clearly explained.
- Many confusing variable names: for instance, there are variables named `ite` and `iterator` that are defined in the same scope.
- Some code was misindented.
- The `*serdeClass` variables are actually expected to be single-quoted strings, which is really confusing: I feel that this parsing / extraction should be performed in the analyzer, not in the operator itself.
- There were no unit tests for the operator itself, only end-to-end tests.

This pull request addresses these issues, borrowing some error-handling techniques from PySpark's PythonRDD.

Author: Josh Rosen

Closes #7710 from JoshRosen/script-transform and squashes the following commits:

16c44e2 [Josh Rosen] Update some comments
983f200 [Josh Rosen] Use unescapeSQLString instead of stripQuotes
6a06a8c [Josh Rosen] Clean up handling of quotes in serde class name
494cde0 [Josh Rosen] Propagate TaskContext to writer thread
323bb2b [Josh Rosen] Fix error-swallowing bug
b31258d [Josh Rosen] Rename iterator variables to disambiguate.
88278de [Josh Rosen] Split ScriptTransformation writer thread into own class.
8b162b6 [Josh Rosen] Add failing test which demonstrates exception masking issue
4ee36a2 [Josh Rosen] Kill script transform subprocess when error occurs in input writer.
bd4c948 [Josh Rosen] Skip launching of external command for empty partitions.
b43e4ec [Josh Rosen] Clean up nullability in ScriptTransformation
fa18d26 [Josh Rosen] Add basic unit test for script transform with 'cat' command.
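The core error-handling pattern of the fix can be sketched as follows. This is a generic simplification, not the operator code itself (the real class is ScriptTransformationWriterThread in the diff below), and SketchWriterThread plus its parameters are invented for illustration. The writer runs as a daemon thread, records any failure in a volatile field, destroys the child process, and the reader re-throws the recorded failure once the script's output is exhausted instead of silently returning fewer rows:

    import java.io.OutputStream
    import scala.util.control.NonFatal

    class SketchWriterThread(
        rows: Iterator[String],
        out: OutputStream,
        proc: Process) extends Thread("Thread-Sketch-Writer") {

      setDaemon(true) // must not keep the JVM alive if the task dies

      @volatile private var _exception: Throwable = null
      def exception: Option[Throwable] = Option(_exception)

      override def run(): Unit = {
        try {
          rows.foreach(r => out.write((r + "\n").getBytes("utf-8")))
          out.close()
        } catch {
          case NonFatal(e) =>
            _exception = e // make the failure visible to the reader thread
            proc.destroy() // kill the child so the reader does not block forever
            throw e
        }
      }
    }

    // Reader side, once the script's output ends:
    //   writerThread.exception.foreach(e => throw e)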
--- .../spark/sql/execution/SparkPlanTest.scala | 27 +- .../org/apache/spark/sql/hive/HiveQl.scala | 10 +- .../hive/execution/ScriptTransformation.scala | 280 +++++++++++------- .../execution/ScriptTransformationSuite.scala | 123 ++++++++ 4 files changed, 317 insertions(+), 123 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index 6a8f394545816..f46855edfe0de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row} +import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row} import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -33,11 +33,13 @@ import scala.util.control.NonFatal */ class SparkPlanTest extends SparkFunSuite { + protected def sqlContext: SQLContext = TestSQLContext + /** * Creates a DataFrame from a local Seq of Product. */ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = { - TestSQLContext.implicits.localSeqToDataFrameHolder(data) + sqlContext.implicits.localSeqToDataFrameHolder(data) } /** @@ -98,7 +100,7 @@ class SparkPlanTest extends SparkFunSuite { planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match { + SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -121,7 +123,8 @@ class SparkPlanTest extends SparkFunSuite { planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, sortAnswers: Boolean = true): Unit = { - SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match { + SparkPlanTest.checkAnswer( + input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match { case Some(errorMessage) => fail(errorMessage) case None => } @@ -147,13 +150,14 @@ object SparkPlanTest { input: DataFrame, planFunction: SparkPlan => SparkPlan, expectedPlanFunction: SparkPlan => SparkPlan, - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = planFunction(input.queryExecution.sparkPlan) val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan) val expectedAnswer: Seq[Row] = try { - executePlan(expectedOutputPlan) + executePlan(expectedOutputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -168,7 +172,7 @@ object SparkPlanTest { } val actualAnswer: Seq[Row] = try { - executePlan(outputPlan) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -207,12 +211,13 @@ object SparkPlanTest { input: Seq[DataFrame], planFunction: Seq[SparkPlan] => SparkPlan, expectedAnswer: Seq[Row], - sortAnswers: Boolean): Option[String] = { + sortAnswers: Boolean, + sqlContext: SQLContext): Option[String] = { val outputPlan = 
planFunction(input.map(_.queryExecution.sparkPlan)) val sparkAnswer: Seq[Row] = try { - executePlan(outputPlan) + executePlan(outputPlan, sqlContext) } catch { case NonFatal(e) => val errorMessage = @@ -275,10 +280,10 @@ object SparkPlanTest { } } - private def executePlan(outputPlan: SparkPlan): Seq[Row] = { + private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { // A very simple resolver to make writing tests easier. In contrast to the real resolver // this is always case sensitive and does not try to handle scoping or complex type resolution. - val resolvedPlan = TestSQLContext.prepareForExecution.execute( + val resolvedPlan = sqlContext.prepareForExecution.execute( outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 2f79b0aad045c..e6df64d2642bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -874,15 +874,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } def matchSerDe(clause: Seq[ASTNode]) - : (Seq[(String, String)], String, Seq[(String, String)]) = clause match { + : (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match { case Token("TOK_SERDEPROPS", propsClause) :: Nil => val rowFormat = propsClause.map { case Token(name, Token(value, Nil) :: Nil) => (name, value) } - (rowFormat, "", Nil) + (rowFormat, None, Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, serdeClass, Nil) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil) case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Token("TOK_TABLEPROPERTIES", @@ -891,9 +891,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => (name, value) } - (Nil, serdeClass, serdeProps) + (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps) - case Nil => (Nil, "", Nil) + case Nil => (Nil, None, Nil) } val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 205e622195f09..741c705e2a253 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -17,15 +17,18 @@ package org.apache.spark.sql.hive.execution -import java.io.{BufferedReader, DataInputStream, DataOutputStream, EOFException, InputStreamReader} +import java.io._ import java.util.Properties +import javax.annotation.Nullable import scala.collection.JavaConversions._ +import scala.util.control.NonFatal import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.spark.{TaskContext, Logging} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters @@ -56,21 +59,53 @@ case class ScriptTransformation( override def otherCopyArgs: Seq[HiveContext] = sc :: Nil protected 
override def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => + def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) - // We need to start threads connected to the process pipeline: - // 1) The error msg generated by the script process would be hidden. - // 2) If the error msg is too big to chock up the buffer, the input logic would be hung + val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream val errorStream = proc.getErrorStream - val reader = new BufferedReader(new InputStreamReader(inputStream)) - val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. + val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + "Thread-ScriptTransformation-STDERR-Consumer").start() + + val outputProjection = new InterpretedProjection(input, child.output) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = new ScriptTransformationWriterThread( + inputIterator, + outputProjection, + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get() + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + ioschema.initOutputSerDe(output).getOrElse((null, null)) + } - val iterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + val reader = new BufferedReader(new InputStreamReader(inputStream)) + val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { var cacheRow: InternalRow = null var curLine: String = null var eof: Boolean = false @@ -79,12 +114,26 @@ case class ScriptTransformation( if (outputSerde == null) { if (curLine == null) { curLine = reader.readLine() - curLine != null + if (curLine == null) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } else { true } } else { - !eof + if (eof) { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + false + } else { + true + } } } @@ -110,11 +159,11 @@ case class ScriptTransformation( } i += 1 }) - return mutableRow + mutableRow } catch { case e: EOFException => eof = true - return null + null } } @@ -146,49 +195,83 @@ case class ScriptTransformation( } } - val (inputSerde, inputSoi) = ioschema.initInputSerDe(input) - val dataOutputStream = new DataOutputStream(outputStream) - val outputProjection = new InterpretedProjection(input, child.output) + writerThread.start() - // TODO make the 2048 configurable? - val stderrBuffer = new CircularBuffer(2048) - // Consume the error stream from the pipeline, otherwise it will be blocked if - // the pipeline is full. 
- new RedirectThread(errorStream, // input stream from the pipeline - stderrBuffer, // output to a circular buffer - "Thread-ScriptTransformation-STDERR-Consumer").start() + outputIterator + } - // Put the write(output to the pipeline) into a single thread - // and keep the collector as remain in the main thread. - // otherwise it will causes deadlock if the data size greater than - // the pipeline / buffer capacity. - new Thread(new Runnable() { - override def run(): Unit = { - Utils.tryWithSafeFinally { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) - } - } - outputStream.close() - } { - if (proc.waitFor() != 0) { - logError(stderrBuffer.toString) // log the stderr circular buffer - } - } - } - }, "Thread-ScriptTransformation-Feed").start() + child.execute().mapPartitions { iter => + if (iter.hasNext) { + processIterator(iter) + } else { + // If the input iterator has no rows then do not launch the external script. + Iterator.empty + } + } + } +} - iterator +private class ScriptTransformationWriterThread( + iter: Iterator[InternalRow], + outputProjection: Projection, + @Nullable inputSerde: AbstractSerDe, + @Nullable inputSoi: ObjectInspector, + ioschema: HiveScriptIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext + ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + + setDaemon(true) + + @volatile private var _exception: Throwable = null + + /** Contains the exception thrown while writing the parent iterator to the external process. */ + def exception: Option[Throwable] = Option(_exception) + + override def run(): Unit = Utils.logUncaughtExceptions { + TaskContext.setTaskContext(taskContext) + + val dataOutputStream = new DataOutputStream(outputStream) + + // We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so + // let's use a variable to record whether the `finally` block was hit due to an exception + var threwException: Boolean = true + try { + iter.map(outputProjection).foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + threwException = false + } catch { + case NonFatal(e) => + // An error occurred while writing input, so kill the child process. 
According to the + // Javadoc this call will not throw an exception: + _exception = e + proc.destroy() + throw e + } finally { + try { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer + } + } catch { + case NonFatal(exceptionFromFinallyBlock) => + if (!threwException) { + throw exceptionFromFinallyBlock + } else { + log.error("Exception in finally block", exceptionFromFinallyBlock) + } + } } } } @@ -200,33 +283,43 @@ private[hive] case class HiveScriptIOSchema ( inputRowFormat: Seq[(String, String)], outputRowFormat: Seq[(String, String)], - inputSerdeClass: String, - outputSerdeClass: String, + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], inputSerdeProps: Seq[(String, String)], outputSerdeProps: Seq[(String, String)], schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors { - val defaultFormat = Map(("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n")) + private val defaultFormat = Map( + ("TOK_TABLEROWFORMATFIELD", "\t"), + ("TOK_TABLEROWFORMATLINES", "\n") + ) val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - def initInputSerDe(input: Seq[Expression]): (AbstractSerDe, ObjectInspector) = { - val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(inputSerdeClass, columns, columnTypes, inputSerdeProps) - (serde, initInputSoi(serde, columns, columnTypes)) + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + inputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(input) + val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val fieldObjectInspectors = columnTypes.map(toInspector) + val objectInspector = ObjectInspectorFactory + .getStandardStructObjectInspector(columns, fieldObjectInspectors) + .asInstanceOf[ObjectInspector] + (serde, objectInspector) + } } - def initOutputSerDe(output: Seq[Attribute]): (AbstractSerDe, StructObjectInspector) = { - val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(outputSerdeClass, columns, columnTypes, outputSerdeProps) - (serde, initOutputputSoi(serde)) + def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + outputSerdeClass.map { serdeClass => + val (columns, columnTypes) = parseAttrs(output) + val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] + (serde, structObjectInspector) + } } - def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { - + private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) = { val columns = attrs.map { case aref: AttributeReference => aref.name case e: NamedExpression => e.name @@ -242,52 +335,25 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - def initSerDe(serdeClassName: String, columns: Seq[String], - columnTypes: Seq[DataType], serdeProps: Seq[(String, String)]): AbstractSerDe = { + private def initSerDe( + serdeClassName: String, + columns: Seq[String], + columnTypes: Seq[DataType], + serdeProps: Seq[(String, String)]): AbstractSerDe = { - val serde: AbstractSerDe = if (serdeClassName != "") { - val trimed_class = serdeClassName.split("'")(1) - Utils.classForName(trimed_class) - .newInstance.asInstanceOf[AbstractSerDe] - } else { - null - } + val serde = 
Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe] - if (serde != null) { - val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") + val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",") - var propsMap = serdeProps.map(kv => { - (kv._1.split("'")(1), kv._2.split("'")(1)) - }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) - propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) + var propsMap = serdeProps.map(kv => { + (kv._1.split("'")(1), kv._2.split("'")(1)) + }).toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(",")) + propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames) - val properties = new Properties() - properties.putAll(propsMap) - serde.initialize(null, properties) - } + val properties = new Properties() + properties.putAll(propsMap) + serde.initialize(null, properties) serde } - - def initInputSoi(inputSerde: AbstractSerDe, columns: Seq[String], columnTypes: Seq[DataType]) - : ObjectInspector = { - - if (inputSerde != null) { - val fieldObjectInspectors = columnTypes.map(toInspector(_)) - ObjectInspectorFactory - .getStandardStructObjectInspector(columns, fieldObjectInspectors) - .asInstanceOf[ObjectInspector] - } else { - null - } - } - - def initOutputputSoi(outputSerde: AbstractSerDe): StructObjectInspector = { - if (outputSerde != null) { - outputSerde.getObjectInspector().asInstanceOf[StructObjectInspector] - } else { - null - } - } } - diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala new file mode 100644 index 0000000000000..0875232aede3e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.types.StringType + +class ScriptTransformationSuite extends SparkPlanTest { + + override def sqlContext: SQLContext = TestHive + + private val noSerdeIOSchema = HiveScriptIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + schemaLess = false + ) + + private val serdeIOSchema = noSerdeIOSchema.copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) + + test("cat without SerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("cat with LazySimpleSerDe") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + + test("script transformation should not swallow errors from upstream operators (no serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = noSerdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } + + test("script transformation should not swallow errors from upstream operators (with serde)") { + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => new ScriptTransformation( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = serdeIOSchema + )(TestHive), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + } +} + +private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. + Thread.sleep(1000) // This sleep gives the external process time to start. 
+      throw new IllegalArgumentException("intentional exception")
+    }
+  }
+  override def output: Seq[Attribute] = child.output
+}

From c5ed36953f840018f603dfde94fcb4651e5246ac Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Tue, 28 Jul 2015 16:41:56 -0700
Subject: [PATCH 123/219] [STREAMING] [HOTFIX] Ignore ReceiverTrackerSuite flaky test

Author: Tathagata Das

Closes #7738 from tdas/ReceiverTrackerSuite-hotfix and squashes the following commits:

00f0ee1 [Tathagata Das] ignore flaky test

---
 .../apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index e2159bd4f225d..b039233f36316 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -31,7 +31,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test")
   val ssc = new StreamingContext(sparkConf, Milliseconds(100))

-  test("Receiver tracker - propagates rate limit") {
+  ignore("Receiver tracker - propagates rate limit") {
     object ReceiverStartedWaiter extends StreamingListener {
       @volatile var started = false

From b7f54119f86f916481aeccc67f07e77dc2a924c7 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 28 Jul 2015 17:03:59 -0700
Subject: [PATCH 124/219] [SPARK-9420][SQL] Move expressions in sql/core package to catalyst.

Since the catalyst package already depends on Spark core, we can move those
expressions into catalyst and simplify the function registry.

This is a follow-up to #7478.

Author: Reynold Xin

Closes #7735 from rxin/SPARK-8003 and squashes the following commits:

2ffbdc3 [Reynold Xin] [SPARK-8003][SQL] Move expressions in sql/core package to catalyst.
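For context, the practical effect of the registration change can be shown with a short usage sketch. This is illustrative only and assumes an existing SparkContext named sc and the 1.5-era SQLContext API; once spark_partition_id is registered in FunctionRegistry (see the diff below), it resolves by name from SQL with no sql/core-specific wiring:

    import org.apache.spark.sql.SQLContext

    // Illustrative usage sketch, not part of this patch.
    val sqlContext = new SQLContext(sc)
    sqlContext.range(0, 10).registerTempTable("t")
    sqlContext.sql("SELECT spark_partition_id() AS pid, id FROM t").show()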
--- .../sql/catalyst/analysis/Analyzer.scala | 3 ++- .../catalyst/analysis/FunctionRegistry.scala | 17 +++++++------- .../MonotonicallyIncreasingID.scala | 3 +-- .../expressions/SparkPartitionID.scala | 3 +-- .../expressions}/NondeterministicSuite.scala | 4 +--- .../org/apache/spark/sql/SQLContext.scala | 11 +-------- .../sql/execution/expressions/package.scala | 23 ------------------- .../org/apache/spark/sql/functions.scala | 4 ++-- .../scala/org/apache/spark/sql/UDFSuite.scala | 4 ++-- .../apache/spark/sql/hive/HiveContext.scala | 13 ++--------- .../org/apache/spark/sql/hive/UDFSuite.scala | 4 ++-- 11 files changed, 23 insertions(+), 66 deletions(-) rename sql/{core/src/main/scala/org/apache/spark/sql/execution => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/expressions/MonotonicallyIncreasingID.scala (95%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution => catalyst/src/main/scala/org/apache/spark/sql/catalyst}/expressions/SparkPartitionID.scala (93%) rename sql/{core/src/test/scala/org/apache/spark/sql/execution/expression => catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions}/NondeterministicSuite.scala (83%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a723e92114b32..a309ee35ee582 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.expressions._ @@ -25,7 +27,6 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ -import scala.collection.mutable.ArrayBuffer /** * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. 
Used for testing diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9b60943a1e147..372f80d4a8b16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -161,13 +161,6 @@ object FunctionRegistry { expression[ToDegrees]("degrees"), expression[ToRadians]("radians"), - // misc functions - expression[Md5]("md5"), - expression[Sha2]("sha2"), - expression[Sha1]("sha1"), - expression[Sha1]("sha"), - expression[Crc32]("crc32"), - // aggregate functions expression[Average]("avg"), expression[Count]("count"), @@ -229,7 +222,15 @@ object FunctionRegistry { expression[Year]("year"), // collection functions - expression[Size]("size") + expression[Size]("size"), + + // misc functions + expression[Crc32]("crc32"), + expression[Md5]("md5"), + expression[Sha1]("sha"), + expression[Sha1]("sha1"), + expression[Sha2]("sha2"), + expression[SparkPartitionID]("spark_partition_id") ) val builtin: FunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index eca36b3274420..291b7a5bc3af5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expressions +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{LongType, DataType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 98c8eab8372aa..3f6480bbf0114 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -15,11 +15,10 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.expressions +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, LeafExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types.{IntegerType, DataType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala similarity index 83% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala index b6e79ff9cc95d..82894822ab0f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/expression/NondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala @@ -15,11 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.expression +package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions. ExpressionEvalHelper -import org.apache.spark.sql.execution.expressions.{SparkPartitionID, MonotonicallyIncreasingID} class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("MonotonicallyIncreasingID") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 56cd8f22e7cf4..dbb2a09846548 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,8 +31,6 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder} -import org.apache.spark.sql.execution.expressions.SparkPartitionID import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.errors.DialectException @@ -142,14 +140,7 @@ class SQLContext(@transient val sparkContext: SparkContext) // TODO how to handle the temp function per user session? 
@transient - protected[sql] lazy val functionRegistry: FunctionRegistry = { - val reg = FunctionRegistry.builtin - val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))]( - FunctionExpression[SparkPartitionID]("spark__partition__id") - ) - extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) } - reg - } + protected[sql] lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin @transient protected[sql] lazy val analyzer: Analyzer = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala deleted file mode 100644 index 568b7ac2c5987..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/package.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -/** - * Package containing expressions that are specific to Spark runtime. - */ -package object expressions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0148991512213..4261a5e7cbeb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -634,7 +634,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() /** * Return an alternative value `r` if `l` is NaN. @@ -741,7 +741,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = execution.expressions.SparkPartitionID() + def sparkPartitionId(): Column = SparkPartitionID() /** * Computes the square root of the specified float value. 
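As an aside on the two functions changed just above: their public DataFrame
API is untouched; only the expression classes they instantiate have moved
packages. A hedged usage sketch (assuming a SQLContext named sqlContext):

    import org.apache.spark.sql.functions.{monotonicallyIncreasingId, sparkPartitionId}

    // Both functions now build catalyst expressions directly instead of
    // reaching into the removed sql.execution.expressions package.
    val df = sqlContext.range(0, 10, step = 1, numPartitions = 2)
    df.select(df("id"), monotonicallyIncreasingId(), sparkPartitionId()).show()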
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 9b326c16350c8..d9c8b380ef146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -51,10 +51,10 @@ class UDFSuite extends QueryTest { df.selectExpr("count(distinct a)") } - test("SPARK-8003 spark__partition__id") { + test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying") df.registerTempTable("tmp_table") - checkAnswer(ctx.sql("select spark__partition__id() from tmp_table").toDF(), Row(0)) + checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0)) ctx.dropTempTable("tmp_table") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8b35c1275f388..110f51a305861 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -38,9 +38,6 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{expression => FunctionExpression, FunctionBuilder} -import org.apache.spark.sql.catalyst.expressions.ExpressionInfo -import org.apache.spark.sql.execution.expressions.SparkPartitionID import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.{TableIdentifier, ParserDialect} @@ -375,14 +372,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { // Note that HiveUDFs will be overridden by functions registered in this context. @transient - override protected[sql] lazy val functionRegistry: FunctionRegistry = { - val reg = new HiveFunctionRegistry(FunctionRegistry.builtin) - val extendedFunctions = List[(String, (ExpressionInfo, FunctionBuilder))]( - FunctionExpression[SparkPartitionID]("spark__partition__id") - ) - extendedFunctions.foreach { case(name, (info, fun)) => reg.registerFunction(name, info, fun) } - reg - } + override protected[sql] lazy val functionRegistry: FunctionRegistry = + new HiveFunctionRegistry(FunctionRegistry.builtin) /* An analyzer that uses the Hive metastore. 
*/ @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 9cea5d413c817..37afc2142abf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -35,9 +35,9 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } - test("SPARK-8003 spark__partition__id") { + test("SPARK-8003 spark_partition_id") { val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying") ctx.registerDataFrameAsTable(df, "test_table") - checkAnswer(ctx.sql("select spark__partition__id() from test_table LIMIT 1").toDF(), Row(0)) + checkAnswer(ctx.sql("select spark_partition_id() from test_table LIMIT 1").toDF(), Row(0)) } } From 6662ee21244067180c1bcef0b16107b2979fd933 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 28 Jul 2015 17:42:35 -0700 Subject: [PATCH 125/219] [SPARK-9418][SQL] Use sort-merge join as the default shuffle join. Sort-merge join is more robust in Spark since sorting can be made using the Tungsten sort operator. Author: Reynold Xin Closes #7733 from rxin/smj and squashes the following commits: 61e4d34 [Reynold Xin] Fixed test case. 5ffd731 [Reynold Xin] Fixed JoinSuite. a137dc0 [Reynold Xin] [SPARK-9418][SQL] Use sort-merge join as the default shuffle join. --- .../src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 6 +++--- ...bilitySuite.scala => HashJoinCompatibilitySuite.scala} | 8 ++++---- .../scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) rename sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/{SortMergeCompatibilitySuite.scala => HashJoinCompatibilitySuite.scala} (97%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 40eba33f595ca..cdb0c7a1c07a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -322,7 +322,7 @@ private[spark] object SQLConf { " memory.") val SORTMERGE_JOIN = booleanConf("spark.sql.planner.sortMergeJoin", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use sort merge join (as opposed to hash join) by default for large joins.") // This is only used for the thriftserver diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index dfb2a7e099748..666f26bf620e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -79,9 +79,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]), ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]), ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]), - ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]), - ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a", 
classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), + ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]), ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashOuterJoin]), diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala similarity index 97% rename from sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala rename to sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala index 1fe4fe9629c02..1a5ba20404c4e 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HashJoinCompatibilitySuite.scala @@ -23,16 +23,16 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive /** - * Runs the test cases that are included in the hive distribution with sort merge join is true. + * Runs the test cases that are included in the hive distribution with hash joins. */ -class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { +class HashJoinCompatibilitySuite extends HiveCompatibilitySuite { override def beforeAll() { super.beforeAll() - TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) } override def afterAll() { - TestHive.setConf(SQLConf.SORTMERGE_JOIN, false) + TestHive.setConf(SQLConf.SORTMERGE_JOIN, true) super.afterAll() } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f067ea0d4fc75..bc72b0172a467 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -172,7 +172,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j } assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off") - val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j } + val shj = df.queryExecution.sparkPlan.collect { case j: SortMergeJoin => j } assert(shj.size === 1, "ShuffledHashJoin should be planned when BroadcastHashJoin is turned off") From e78ec1a8fabfe409c92c4904208f53dbdcfcf139 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 17:51:58 -0700 Subject: [PATCH 126/219] [SPARK-9421] Fix null-handling bugs in UnsafeRow.getDouble, getFloat(), and get(ordinal, dataType) UnsafeRow.getDouble and getFloat() return NaN when called on columns that are null, which is inconsistent with the behavior of other row classes (which is to return 0.0). In addition, the generic get(ordinal, dataType) method should always return null for a null literal, but currently it handles nulls by calling the type-specific accessors. This patch addresses both of these issues and adds a regression test. 
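To pin down the intended semantics, a small sketch that mirrors the regression
test added below (assumes the catalyst classes are on the classpath; plain
asserts stand in for ScalaTest matchers):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types._

    val row = InternalRow.apply(null, null)
    val unsafeRow =
      UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row)
    // Null slots now read as 0.0, matching other InternalRow implementations
    // (previously UnsafeRow returned NaN here).
    assert(unsafeRow.getFloat(0) == 0.0f)
    assert(unsafeRow.getDouble(1) == 0.0)
    // The generic accessor now short-circuits on null instead of delegating
    // to the type-specific getters.
    assert(unsafeRow.get(0, FloatType) == null)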
Author: Josh Rosen Closes #7736 from JoshRosen/unsafe-row-null-fixes and squashes the following commits: c8eb2ee [Josh Rosen] Fix test in UnsafeRowConverterSuite 6214682 [Josh Rosen] Fixes to null handling in UnsafeRow --- .../sql/catalyst/expressions/UnsafeRow.java | 14 +++----------- .../expressions/UnsafeRowConverterSuite.scala | 4 ++-- .../org/apache/spark/sql/UnsafeRowSuite.scala | 17 ++++++++++++++++- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 955fb4226fc0e..64a8edc34d681 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -239,7 +239,7 @@ public Object get(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - if (dataType instanceof NullType) { + if (isNullAt(ordinal) || dataType instanceof NullType) { return null; } else if (dataType instanceof BooleanType) { return getBoolean(ordinal); @@ -313,21 +313,13 @@ public long getLong(int ordinal) { @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } @Override diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 2834b54e8fb2e..b7bc17f89e82f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -146,8 +146,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getShort(3) === 0) assert(createdFromNull.getInt(4) === 0) assert(createdFromNull.getLong(5) === 0) - assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getFloat(6) === 0.0f) + assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) // assert(createdFromNull.get(10) === null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index ad3bb1744cb3c..e72a1bc6c4e20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -67,4 +67,19 @@ class UnsafeRowSuite extends SparkFunSuite { assert(bytesFromArrayBackedRow === bytesFromOffheapRow) } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } } From 3744b7fd42e52011af60cc205fcb4e4b23b35c68 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 28 Jul 2015 19:01:25 -0700 Subject: [PATCH 127/219] [SPARK-9422] [SQL] Remove the placeholder attributes used in the aggregation buffers https://issues.apache.org/jira/browse/SPARK-9422 Author: Yin Huai Closes #7737 from yhuai/removePlaceHolder and squashes the following commits: ec29b44 [Yin Huai] Remove placeholder attributes. --- .../expressions/aggregate/interfaces.scala | 27 ++- .../aggregate/aggregateOperators.scala | 4 +- .../aggregate/sortBasedIterators.scala | 209 +++++++----------- .../spark/sql/execution/aggregate/udaf.scala | 17 +- .../spark/sql/execution/aggregate/utils.scala | 4 +- 5 files changed, 121 insertions(+), 140 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 10bd19c8a840f..9fb7623172e78 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -103,9 +103,30 @@ abstract class AggregateFunction2 final override def foldable: Boolean = false /** - * The offset of this function's buffer in the underlying buffer shared with other functions. + * The offset of this function's start buffer value in the + * underlying shared mutable aggregation buffer. + * For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share + * the same aggregation buffer. In this shared buffer, the position of the first + * buffer value of `avg(x)` will be 0 and the position of the first buffer value of `avg(y)` + * will be 2. */ - var bufferOffset: Int = 0 + var mutableBufferOffset: Int = 0 + + /** + * The offset of this function's start buffer value in the + * underlying shared input aggregation buffer. An input aggregation buffer is used + * when we merge two aggregation buffers and it is basically the immutable one + * (we merge an input aggregation buffer and a mutable aggregation buffer and + * then store the new buffer values to the mutable aggregation buffer). + * Usually, an input aggregation buffer also contain extra elements like grouping + * keys at the beginning. So, mutableBufferOffset and inputBufferOffset are often + * different. + * For example, we have a grouping expression `key``, and two aggregate functions + * `avg(x)` and `avg(y)`. 
In this shared input aggregation buffer, the position of the first + * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)` + * will be 3 (position 0 is used for the value of key`). + */ + var inputBufferOffset: Int = 0 /** The schema of the aggregation buffer. */ def bufferSchema: StructType @@ -176,7 +197,7 @@ abstract class AlgebraicAggregate extends AggregateFunction2 with Serializable w override def initialize(buffer: MutableRow): Unit = { var i = 0 while (i < bufferAttributes.size) { - buffer(i + bufferOffset) = initialValues(i).eval() + buffer(i + mutableBufferOffset) = initialValues(i).eval() i += 1 } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala index 0c9082897f390..98538c462bc89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/aggregateOperators.scala @@ -72,8 +72,10 @@ case class Aggregate2Sort( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => if (aggregateExpressions.length == 0) { - new GroupingIterator( + new FinalSortAggregationIterator( groupingExpressions, + Nil, + Nil, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala index 1b89edafa8dad..2ca0cb82c1aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -41,7 +41,8 @@ private[sql] abstract class SortAggregationIterator( /////////////////////////////////////////////////////////////////////////// protected val aggregateFunctions: Array[AggregateFunction2] = { - var bufferOffset = initialBufferOffset + var mutableBufferOffset = 0 + var inputBufferOffset: Int = initialInputBufferOffset val functions = new Array[AggregateFunction2](aggregateExpressions.length) var i = 0 while (i < aggregateExpressions.length) { @@ -54,13 +55,18 @@ private[sql] abstract class SortAggregationIterator( // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. BindReferences.bindReference(func, inputAttributes) - case _ => func + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + func.inputBufferOffset = inputBufferOffset + inputBufferOffset += func.bufferSchema.length + func } - // Set bufferOffset for this function. It is important that setting bufferOffset - // happens after all potential bindReference operations because bindReference - // will create a new instance of the function. - funcWithBoundReferences.bufferOffset = bufferOffset - bufferOffset += funcWithBoundReferences.bufferSchema.length + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. 
+ funcWithBoundReferences.mutableBufferOffset = mutableBufferOffset + mutableBufferOffset += funcWithBoundReferences.bufferSchema.length functions(i) = funcWithBoundReferences i += 1 } @@ -97,25 +103,24 @@ private[sql] abstract class SortAggregationIterator( // The number of elements of the underlying buffer of this operator. // All aggregate functions are sharing this underlying buffer and they find their // buffer values through bufferOffset. - var size = initialBufferOffset - var i = 0 - while (i < aggregateFunctions.length) { - size += aggregateFunctions(i).bufferSchema.length - i += 1 - } - new GenericMutableRow(size) + // var size = 0 + // var i = 0 + // while (i < aggregateFunctions.length) { + // size += aggregateFunctions(i).bufferSchema.length + // i += 1 + // } + new GenericMutableRow(aggregateFunctions.map(_.bufferSchema.length).sum) } protected val joinedRow = new JoinedRow - protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) - // This projection is used to initialize buffer values for all AlgebraicAggregates. protected val algebraicInitialProjection = { - val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val initExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.initialValues case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } + newMutableProjection(initExpressions, Nil)().target(buffer) } @@ -132,10 +137,6 @@ private[sql] abstract class SortAggregationIterator( // Indicates if we has new group of rows to process. protected var hasNewGroup: Boolean = true - /////////////////////////////////////////////////////////////////////////// - // Private methods - /////////////////////////////////////////////////////////////////////////// - /** Initializes buffer values for all aggregate functions. */ protected def initializeBuffer(): Unit = { algebraicInitialProjection(EmptyRow) @@ -160,6 +161,10 @@ private[sql] abstract class SortAggregationIterator( } } + /////////////////////////////////////////////////////////////////////////// + // Private methods + /////////////////////////////////////////////////////////////////////////// + /** Processes rows in the current group. It will stop when it find a new group. */ private def processCurrentGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -218,10 +223,13 @@ private[sql] abstract class SortAggregationIterator( // Methods that need to be implemented /////////////////////////////////////////////////////////////////////////// - protected def initialBufferOffset: Int + /** The initial input buffer offset for `inputBufferOffset` of an [[AggregateFunction2]]. */ + protected def initialInputBufferOffset: Int + /** The function used to process an input row. */ protected def processRow(row: InternalRow): Unit + /** The function used to generate the result row. */ protected def generateOutput(): InternalRow /////////////////////////////////////////////////////////////////////////// @@ -231,37 +239,6 @@ private[sql] abstract class SortAggregationIterator( initialize() } -/** - * An iterator only used to group input rows according to values of `groupingExpressions`. - * It assumes that input rows are already grouped by values of `groupingExpressions`. 
- */ -class GroupingIterator( - groupingExpressions: Seq[NamedExpression], - resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), - inputAttributes: Seq[Attribute], - inputIter: Iterator[InternalRow]) - extends SortAggregationIterator( - groupingExpressions, - Nil, - newMutableProjection, - inputAttributes, - inputIter) { - - private val resultProjection = - newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))() - - override protected def initialBufferOffset: Int = 0 - - override protected def processRow(row: InternalRow): Unit = { - // Since we only do grouping, there is nothing to do at here. - } - - override protected def generateOutput(): InternalRow = { - resultProjection(currentGroupingKey) - } -} - /** * An iterator used to do partial aggregations (for those aggregate functions with mode Partial). * It assumes that input rows are already grouped by values of `groupingExpressions`. @@ -291,7 +268,7 @@ class PartialSortAggregationIterator( newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) } - override protected def initialBufferOffset: Int = 0 + override protected def initialInputBufferOffset: Int = 0 override protected def processRow(row: InternalRow): Unit = { // Process all algebraic aggregate functions. @@ -318,11 +295,7 @@ class PartialSortAggregationIterator( * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| * * The format of its internal buffer is: - * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN| - * Every placeholder is for a grouping expression. - * The actual buffers are stored after placeholderN. - * The reason that we have placeholders at here is to make our underlying buffer have the same - * length with a input row. + * |aggregationBuffer1|...|aggregationBufferN| * * The format of its output rows is: * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| @@ -340,33 +313,21 @@ class PartialMergeSortAggregationIterator( inputAttributes, inputIter) { - private val placeholderAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { - val bufferSchemata = - placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val mergeInputSchema = + aggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingExpressions.map(_.toAttribute) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } - // This projection is used to extract aggregation buffers from the underlying buffer. - // We need it because the underlying buffer has placeholders at its beginning. 
- private val extractsBufferValues = { - val expressions = aggregateFunctions.flatMap { - case agg => agg.bufferAttributes - } - - newMutableProjection(expressions, inputAttributes)() - } - - override protected def initialBufferOffset: Int = groupingExpressions.length + override protected def initialInputBufferOffset: Int = groupingExpressions.length override protected def processRow(row: InternalRow): Unit = { // Process all algebraic aggregate functions. @@ -381,7 +342,7 @@ class PartialMergeSortAggregationIterator( override protected def generateOutput(): InternalRow = { // We output grouping expressions and aggregation buffers. - joinedRow(currentGroupingKey, extractsBufferValues(buffer)) + joinedRow(currentGroupingKey, buffer).copy() } } @@ -393,11 +354,7 @@ class PartialMergeSortAggregationIterator( * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| * * The format of its internal buffer is: - * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN| - * Every placeholder is for a grouping expression. - * The actual buffers are stored after placeholderN. - * The reason that we have placeholders at here is to make our underlying buffer have the same - * length with a input row. + * |aggregationBuffer1|...|aggregationBufferN| * * The format of its output rows is represented by the schema of `resultExpressions`. */ @@ -425,27 +382,23 @@ class FinalSortAggregationIterator( newMutableProjection( resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() - private val offsetAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // This projection is used to merge buffer values for all AlgebraicAggregates. private val algebraicMergeProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) - val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + val mergeInputSchema = + aggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingExpressions.map(_.toAttribute) ++ + aggregateFunctions.flatMap(_.cloneBufferAttributes) + val mergeExpressions = aggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -454,7 +407,7 @@ class FinalSortAggregationIterator( newMutableProjection(evalExpressions, bufferSchemata)() } - override protected def initialBufferOffset: Int = groupingExpressions.length + override protected def initialInputBufferOffset: Int = groupingExpressions.length override def initialize(): Unit = { if (inputIter.hasNext) { @@ -471,7 +424,10 @@ class FinalSortAggregationIterator( // Right now, the buffer only contains initial buffer values. 
Because // merging two buffers with initial values will generate a row that // still store initial values. We set the currentRow as the copy of the current buffer. - val currentRow = buffer.copy() + // Because input aggregation buffer has initialInputBufferOffset extra values at the + // beginning, we create a dummy row for this part. + val currentRow = + joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy() nextGroupingKey = groupGenerator(currentRow).copy() firstRowInNextGroup = currentRow } else { @@ -518,18 +474,15 @@ class FinalSortAggregationIterator( * Final mode. * * The format of its internal buffer is: - * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)| - * The first N placeholders represent slots of grouping expressions. - * Then, next M placeholders represent slots of col1 to colM. + * |aggregationBuffer1|...|aggregationBuffer(N+M)| * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode - * Complete. The reason that we have placeholders at here is to make our underlying buffer - * have the same length with a input row. + * Complete. * * The format of its output rows is represented by the schema of `resultExpressions`. */ class FinalAndCompleteSortAggregationIterator( - override protected val initialBufferOffset: Int, + override protected val initialInputBufferOffset: Int, groupingExpressions: Seq[NamedExpression], finalAggregateExpressions: Seq[AggregateExpression2], finalAggregateAttributes: Seq[Attribute], @@ -561,9 +514,6 @@ class FinalAndCompleteSortAggregationIterator( newMutableProjection(resultExpressions, inputSchema)() } - private val offsetAttributes = - Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) - // All aggregate functions with mode Final. private val finalAggregateFunctions: Array[AggregateFunction2] = { val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) @@ -601,38 +551,38 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to merge buffer values for all AlgebraicAggregates with mode // Final. private val finalAlgebraicMergeProjection = { - val numCompleteOffsetAttributes = - completeAggregateFunctions.map(_.bufferAttributes.length).sum - val completeOffsetAttributes = - Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) - val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) - - val bufferSchemata = - offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++ - completeOffsetAttributes ++ offsetAttributes ++ - finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes + // The first initialInputBufferOffset values of the input aggregation buffer is + // for grouping expressions and distinct columns. 
+ val groupingAttributesAndDistinctColumns = inputAttributes.take(initialInputBufferOffset) + + val completeOffsetExpressions = + Seq.fill(completeAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) + + val mergeInputSchema = + finalAggregateFunctions.flatMap(_.bufferAttributes) ++ + completeAggregateFunctions.flatMap(_.bufferAttributes) ++ + groupingAttributesAndDistinctColumns ++ + finalAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = - placeholderExpressions ++ finalAggregateFunctions.flatMap { + finalAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.mergeExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } ++ completeOffsetExpressions - - newMutableProjection(mergeExpressions, bufferSchemata)() + newMutableProjection(mergeExpressions, mergeInputSchema)() } // This projection is used to update buffer values for all AlgebraicAggregates with mode // Complete. private val completeAlgebraicUpdateProjection = { - val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum - val finalOffsetAttributes = - Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) - val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + // We do not touch buffer values of aggregate functions with the Final mode. + val finalOffsetExpressions = + Seq.fill(finalAggregateFunctions.map(_.bufferAttributes.length).sum)(NoOp) val bufferSchema = - offsetAttributes ++ finalOffsetAttributes ++ + finalAggregateFunctions.flatMap(_.bufferAttributes) ++ completeAggregateFunctions.flatMap(_.bufferAttributes) val updateExpressions = - placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: AlgebraicAggregate => ae.updateExpressions case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) } @@ -641,9 +591,7 @@ class FinalAndCompleteSortAggregationIterator( // This projection is used to evaluate all AlgebraicAggregates. private val algebraicEvalProjection = { - val bufferSchemata = - offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++ - offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes) + val bufferSchemata = aggregateFunctions.flatMap(_.bufferAttributes) val evalExpressions = aggregateFunctions.map { case ae: AlgebraicAggregate => ae.evaluateExpression case agg: AggregateFunction2 => NoOp @@ -667,7 +615,10 @@ class FinalAndCompleteSortAggregationIterator( // Right now, the buffer only contains initial buffer values. Because // merging two buffers with initial values will generate a row that // still store initial values. We set the currentRow as the copy of the current buffer. - val currentRow = buffer.copy() + // Because input aggregation buffer has initialInputBufferOffset extra values at the + // beginning, we create a dummy row for this part. 
+      val currentRow =
+        joinedRow(new GenericInternalRow(initialInputBufferOffset), buffer).copy()
       nextGroupingKey = groupGenerator(currentRow).copy()
       firstRowInNextGroup = currentRow
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 073c45ae2f9f2..cc54319171bdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -184,7 +184,7 @@ private[sql] case class ScalaUDAF(
       bufferSchema,
       bufferValuesToCatalystConverters,
       bufferValuesToScalaConverters,
-      bufferOffset,
+      inputBufferOffset,
       null)

   lazy val mutableAggregateBuffer: MutableAggregationBufferImpl =
@@ -192,9 +192,16 @@
       bufferSchema,
       bufferValuesToCatalystConverters,
       bufferValuesToScalaConverters,
-      bufferOffset,
+      mutableBufferOffset,
       null)

+  lazy val evalAggregateBuffer: InputAggregationBuffer =
+    new InputAggregationBuffer(
+      bufferSchema,
+      bufferValuesToCatalystConverters,
+      bufferValuesToScalaConverters,
+      mutableBufferOffset,
+      null)
+
   override def initialize(buffer: MutableRow): Unit = {
     mutableAggregateBuffer.underlyingBuffer = buffer
@@ -217,10 +224,10 @@ private[sql] case class ScalaUDAF(
     udaf.merge(mutableAggregateBuffer, inputAggregateBuffer)
   }

-  override def eval(buffer: InternalRow = null): Any = {
-    inputAggregateBuffer.underlyingInputBuffer = buffer
+  override def eval(buffer: InternalRow): Any = {
+    evalAggregateBuffer.underlyingInputBuffer = buffer

-    udaf.evaluate(inputAggregateBuffer)
+    udaf.evaluate(evalAggregateBuffer)
   }

   override def toString: String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 5bbe6c162ff4b..6549c87752a7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -292,8 +292,8 @@ object Utils {
         AggregateExpression2(aggregateFunction, PartialMerge, false)
       }
     val partialMergeAggregateAttributes =
-      partialMergeAggregateExpressions.map {
-        expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
+      partialMergeAggregateExpressions.flatMap { agg =>
+        agg.aggregateFunction.bufferAttributes
       }
     val partialMergeAggregate =
       Aggregate2Sort(

From 429b2f0df4ef97a3b94cead06a7eb51581eabb18 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 28 Jul 2015 21:37:50 -0700
Subject: [PATCH 128/219] [SPARK-8608][SPARK-8609][SPARK-9083][SQL] reset
 mutable states of nondeterministic expression before evaluation and fix
 PullOutNondeterministic

We will do local projection for LocalRelation, and thus reuse the same
Expression object across multiple evaluations. We should reset the mutable
states of an Expression before evaluating it. Fix the `PullOutNondeterministic`
rule to make it work for `Sort`. Also got a chance to clean up the DataFrame
test suite.
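A minimal sketch of the first symptom this fixes (illustrative only, assuming
a SQLContext named sqlContext; the regression coverage lives in the
DataFrameSuite changes below):

    import org.apache.spark.sql.functions.rand

    // A DataFrame over a LocalRelation is projected locally, so the same
    // Rand expression instance is evaluated repeatedly. Without resetting
    // its mutable seed state before each evaluation, two show() calls could
    // print different values for the same rows.
    val df = sqlContext.range(0, 3).select(rand(33))
    df.show()
    df.show() // after this patch, prints the same values as the first call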
Author: Wenchen Fan Closes #7674 from cloud-fan/show and squashes the following commits: 888934f [Wenchen Fan] fix sort c0e93e8 [Wenchen Fan] local DataFrame with random columns should return same value when call `show` --- .../sql/catalyst/analysis/Analyzer.scala | 15 +- .../sql/catalyst/expressions/Expression.scala | 8 +- .../sql/catalyst/expressions/Projection.scala | 4 +- .../sql/catalyst/expressions/predicates.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 12 +- .../expressions/ExpressionEvalHelper.scala | 2 +- .../org/apache/spark/sql/DataFrameSuite.scala | 153 +++++++++++------- 7 files changed, 120 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a309ee35ee582..a6ea0cc0a83a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -928,12 +928,17 @@ class Analyzer( // from LogicalPlan, currently we only do it for UnaryNode which has same output // schema with its child. case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) => - val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e => - val ne = e match { - case n: NamedExpression => n - case _ => Alias(e, "_nondeterministic")() + val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr => + val leafNondeterministic = expr.collect { + case n: Nondeterministic => n + } + leafNondeterministic.map { e => + val ne = e match { + case n: NamedExpression => n + case _ => Alias(e, "_nondeterministic")() + } + new TreeNodeRef(e) -> ne } - new TreeNodeRef(e) -> ne }.toMap val newPlan = p.transformExpressions { case e => nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 03e36c7871bcf..8fc182607ce68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -201,11 +201,9 @@ trait Nondeterministic extends Expression { private[this] var initialized = false - final def initialize(): Unit = { - if (!initialized) { - initInternal() - initialized = true - } + final def setInitialValues(): Unit = { + initInternal() + initialized = true } protected def initInternal(): Unit diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 27d6ff587ab71..b3beb7e28f208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -32,7 +32,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { this(expressions.map(BindReferences.bindReference(_, inputSchema))) expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => }) @@ -63,7 +63,7 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu 
this(expressions.map(BindReferences.bindReference(_, inputSchema))) expressions.foreach(_.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 5bfe1cad24a3e..ab7d3afce8f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -31,7 +31,7 @@ object InterpretedPredicate { def create(expression: Expression): (InternalRow => Boolean) = { expression.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => } (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ed645b618dc9b..4589facb49b76 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -153,7 +153,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(4).dataType == DoubleType) } - test("pull out nondeterministic expressions from unary LogicalPlan") { + test("pull out nondeterministic expressions from RepartitionByExpression") { val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) val projected = Alias(Rand(33), "_nondeterministic")() val expected = @@ -162,4 +162,14 @@ class AnalysisSuite extends AnalysisTest { Project(testRelation.output :+ projected, testRelation))) checkAnalysis(plan, expected) } + + test("pull out nondeterministic expressions from Sort") { + val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation) + val projected = Alias(Rand(33), "_nondeterministic")() + val expected = + Project(testRelation.output, + Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false, + Project(testRelation.output :+ projected, testRelation))) + checkAnalysis(plan, expected) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 0c8611d5ddefa..3c05e5c3b833c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -65,7 +65,7 @@ trait ExpressionEvalHelper { protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = { expression.foreach { - case n: Nondeterministic => n.initialize() + case n: Nondeterministic => n.setInitialValues() case _ => } expression.eval(inputRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3151e071b19ea..97beae2f85c50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -33,33 +33,28 @@ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils} class DataFrameSuite extends QueryTest with SQLTestUtils { import 
org.apache.spark.sql.TestData._ - lazy val ctx = org.apache.spark.sql.test.TestSQLContext - import ctx.implicits._ - - def sqlContext: SQLContext = ctx + lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext + import sqlContext.implicits._ test("analysis error should be eagerly reported") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - - intercept[Exception] { testData.select('nonExistentName) } - intercept[Exception] { - testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) - } - intercept[Exception] { - testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) - } - intercept[Exception] { - testData.groupBy($"abcd").agg(Map("key" -> "sum")) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + intercept[Exception] { testData.select('nonExistentName) } + intercept[Exception] { + testData.groupBy('key).agg(Map("nonExistentName" -> "sum")) + } + intercept[Exception] { + testData.groupBy("nonExistentName").agg(Map("key" -> "sum")) + } + intercept[Exception] { + testData.groupBy($"abcd").agg(Map("key" -> "sum")) + } } // No more eager analysis once the flag is turned off - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false) - testData.select('nonExistentName) - - // Set the flag back to original value before this test. - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") { + testData.select('nonExistentName) + } } test("dataframe toString") { @@ -77,21 +72,18 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("invalid plan toString, debug mode") { - val oldSetting = ctx.conf.dataFrameEagerAnalysis - ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true) - // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - ctx.debug() - val badPlan = testData.select('badColumn) + withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") { + sqlContext.debug() - assert(badPlan.toString contains badPlan.queryExecution.toString, - "toString on bad query plans should include the query execution but was:\n" + - badPlan.toString) + val badPlan = testData.select('badColumn) - // Set the flag back to original value before this test. 
- ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting) + assert(badPlan.toString contains badPlan.queryExecution.toString, + "toString on bad query plans should include the query execution but was:\n" + + badPlan.toString) + } } test("access complex data") { @@ -107,8 +99,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("empty data frame") { - assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(ctx.emptyDataFrame.count() === 0) + assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(sqlContext.emptyDataFrame.count() === 0) } test("head and take") { @@ -344,7 +336,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("replace column using withColumn") { - val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -425,7 +417,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("randomSplit") { val n = 600 - val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -519,7 +511,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("showString: truncate = [true, false]") { val longString = Array.fill(21)("1").mkString - val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF() val expectedAnswerForFalse = """+---------------------+ ||_1 | |+---------------------+ @@ -609,21 +601,17 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = ctx.createDataFrame(rowRDD, schema) + val df = sqlContext.createDataFrame(rowRDD, schema) df.rdd.collect() } - test("SPARK-6899") { - val originalValue = ctx.conf.codegenEnabled - ctx.setConf(SQLConf.CODEGEN_ENABLED, true) - try{ + test("SPARK-6899: type should match when using codegen") { + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) - } finally { - ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue) } } @@ -635,14 +623,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = ctx.read.json(ctx.sparkContext.makeRDD( + val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = ctx.read.json(ctx.sparkContext.makeRDD( + val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -662,7 +650,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { } test("SPARK-7324 dropDuplicates") { - val testData = ctx.sparkContext.parallelize( + val testData = sqlContext.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) 
:: (2, 2, 1) :: @@ -710,49 +698,49 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = ctx.range(0, 10, 1, 15).select("id") + val res1 = sqlContext.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = ctx.range(3, 15, 3, 2).select("id") + val res2 = sqlContext.range(3, 15, 3, 2).select("id") assert(res2.count == 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = ctx.range(1, -2).select("id") + val res3 = sqlContext.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = ctx.range(1, -2, -2, 6).select("id") + val res4 = sqlContext.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = ctx.range(-3, -8, -2, 1).select("id") + val res5 = sqlContext.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = ctx.range(-8, -4, 2, 1).select("id") + val res6 = sqlContext.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = ctx.range(-10, -9, -20, 1).select("id") + val res7 = sqlContext.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = ctx.range(10).select("id") + val res10 = sqlContext.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = ctx.range(-1).select("id") + val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) } @@ -819,13 +807,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { // pass case: parquet table (HadoopFsRelation) df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath) - val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath) + val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath) pdf.registerTempTable("parquet_base") insertion.write.insertInto("parquet_base") // pass case: json table (InsertableRelation) df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath) - val jdf = ctx.read.json(tempJsonFile.getCanonicalPath) + val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath) jdf.registerTempTable("json_base") insertion.write.mode(SaveMode.Overwrite).insertInto("json_base") @@ -845,11 +833,54 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed.")) // error case: insert into an OneRowRelation - new DataFrame(ctx, OneRowRelation).registerTempTable("one_row") + new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row") val e3 = 
intercept[AnalysisException] { insertion.write.insertInto("one_row") } assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed.")) } } + + test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + val df = testData.select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33)) + assert(df.showString(5) == df.showString(5)) + } + + test("SPARK-8609: local DataFrame with random columns should return same value after sort") { + // Make sure we can pass this test for both codegen mode and interpreted mode. + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { + checkAnswer(testData.sort(rand(33)), testData.sort(rand(33))) + } + + // We will reuse the same Expression object for LocalRelation. + val df = (1 to 10).map(Tuple1.apply).toDF() + checkAnswer(df.sort(rand(33)), df.sort(rand(33))) + } + + test("SPARK-9083: sort with non-deterministic expressions") { + import org.apache.spark.util.random.XORShiftRandom + + val seed = 33 + val df = (1 to 100).map(Tuple1.apply).toDF("i") + val random = new XORShiftRandom(seed) + val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1) + val actual = df.sort(rand(seed)).collect().map(_.getInt(0)) + assert(expected === actual) + } } From ea49705bd4feb2f25e1b536f0b3ddcfc72a57101 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 21:53:28 -0700 Subject: [PATCH 129/219] [SPARK-9419] ShuffleMemoryManager and MemoryStore should track memory on a per-task, not per-thread, basis Spark's ShuffleMemoryManager and MemoryStore track memory on a per-thread basis, which causes problems in the handful of cases where we have tasks that use multiple threads. In PythonRDD, RRDD, ScriptTransformation, and PipedRDD we consume the input iterator in a separate thread in order to write it to an external process. As a result, these RDDs' input iterators are consumed in a different thread than the thread that created them, which can cause problems in our memory allocation tracking. For example, if allocations are performed in one thread but deallocations are performed in a separate thread, then memory may be leaked or we may get errors complaining that more memory was allocated than was freed. I think that the right way to fix this is to change our accounting to be performed on a per-task instead of per-thread basis. Note that the current per-thread tracking has caused problems in the past; SPARK-3731 (#2668) fixes a memory leak in PythonRDD that was caused by this issue (that fix is no longer necessary as of this patch). Author: Josh Rosen Closes #7734 from JoshRosen/memory-tracking-fixes and squashes the following commits: b4b1702 [Josh Rosen] Propagate TaskContext to writer threads. 57c9b4e [Josh Rosen] Merge remote-tracking branch 'origin/master' into memory-tracking-fixes ed25d3b [Josh Rosen] Address minor PR review comments 44f6497 [Josh Rosen] Fix long line.
7b0f04b [Josh Rosen] Fix ShuffleMemoryManagerSuite f57f3f2 [Josh Rosen] More thread -> task changes fa78ee8 [Josh Rosen] Move Executor's cleanup into Task so that TaskContext is defined when cleanup is performed 5e2f01e [Josh Rosen] Fix capitalization 1b0083b [Josh Rosen] Roll back fix in PySpark, which is no longer necessary 2e1e0f8 [Josh Rosen] Use TaskAttemptIds to track shuffle memory c9e8e54 [Josh Rosen] Use TaskAttemptIds to track unroll memory --- .../apache/spark/api/python/PythonRDD.scala | 6 +- .../scala/org/apache/spark/api/r/RRDD.scala | 2 + .../org/apache/spark/executor/Executor.scala | 4 - .../scala/org/apache/spark/rdd/PipedRDD.scala | 1 + .../org/apache/spark/scheduler/Task.scala | 15 ++- .../spark/shuffle/ShuffleMemoryManager.scala | 88 +++++++++-------- .../apache/spark/storage/MemoryStore.scala | 95 ++++++++++--------- .../shuffle/ShuffleMemoryManagerSuite.scala | 41 +++++--- .../spark/storage/BlockManagerSuite.scala | 84 ++++++++-------- 9 files changed, 184 insertions(+), 152 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 598953ac3bcc8..55e563ee968be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -207,6 +207,7 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -263,11 +264,6 @@ private[spark] class PythonRDD( if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 23a470d6afcae..1cf2824f862ee 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -112,6 +112,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( partition: Int): Unit = { val env = SparkEnv.get + val taskContext = TaskContext.get() val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt val stream = new BufferedOutputStream(output, bufferSize) @@ -119,6 +120,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( override def run(): Unit = { try { SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) val dataOut = new DataOutputStream(stream) dataOut.writeInt(partition) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e76664f1bd7b0..7bc7fce7ae8dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -313,10 +313,6 @@ private[spark] class Executor( } } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala 
b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index defdabf95ac4b..3bb9998e1db44 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -133,6 +133,7 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { + TaskContext.setTaskContext(context) val out = new PrintWriter(proc.getOutputStream) // scalastyle:off println diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d11a00956a9a9..1978305cfefbd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{TaskContextImpl, TaskContext} +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -86,7 +86,18 @@ private[spark] abstract class Task[T]( (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this task for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + } + Utils.tryLogNonFatalError { + // Release memory used by this task for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } + } finally { + TaskContext.unset() + } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8b..f038b722957b8 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,95 +19,101 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory * from this pool and release it as it spills data out. When a task ends, all its memory will be * released by the Executor. * - * This class tries to ensure that each thread gets a reasonable share of memory, instead of some - * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * This class tries to ensure that each task gets a reasonable share of memory, instead of some + * task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N.
Because N varies dynamically, we keep track of the - * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. */ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * Try to acquire up to numBytes memory for the current task, and return the number of bytes * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active threads) before it is forced to spill. This can - * happen if the number of threads increases but an older thread had a lot of memory already. + * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active tasks) before it is forced to spill. This can + * happen if the number of tasks increases but an older task had a lot of memory already. */ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numTasks again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
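// For example (illustrative numbers, not lines from the patch): with maxMemory = 1000
// and 4 active tasks, tryToAcquire never lets a single task hold more than
// 1000 / 4 = 250 bytes in total, and a task holding fewer than 1000 / (2 * 4) = 125
// bytes may block in wait() until at least that much can be granted.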
while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"Task $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. */ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currentTaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends).
*/ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currentTaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index ed609772e6979..6f27f00307f8c 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,11 +44,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a thread + // Pending unroll memory refers to the intermediate memory occupied by a task // after the unroll but before the actual putting of the block in the cache. // This chunk of memory is expected to be released *as soon as* we finish // caching the corresponding block as opposed to until after the task finishes. @@ -250,21 +251,21 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. 
val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + // Previous unroll memory held by this task, for releasing later (only at the very end) + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // later when the task finishes. if (keepUnrolling) { accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) - reservePendingUnrollMemoryForThisThread(amountToRelease) + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) } } } @@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisThread() + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -427,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Take into account the amount of memory currently occupied by unrolling blocks // and minus the pending unroll memory for that block on current thread. 
- val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(threadId, 0L) + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -455,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -482,79 +483,85 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + // In case this is called on the driver, return an invalid task attempt id. + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } } /** - * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * Release memory used by this task for unrolling blocks. + * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory - // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory + // If this task claims no more unroll memory, release it completely + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Reserve the unroll memory of current unroll successful block used by this thread + * Reserve the unroll memory of current unroll successful block used by this task * until actually put the block into memory entry. 
*/ - def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { - val threadId = Thread.currentThread().getId + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } } /** - * Release pending unroll memory of current unroll successful block used by this thread + * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisThread(): Unit = { - val threadId = Thread.currentThread().getId + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(threadId) + pendingUnrollMemoryMap.remove(taskAttemptId) } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -566,7 +573,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"$numTasksUnrolling task(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}."
) } diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb1..f495b6a037958 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,26 +17,39 @@ package org.apache.spark.shuffle +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.CountDownLatch -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { + + val nextTaskAttemptId = new AtomicInteger() + /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { override def run() { - body + try { + val taskAttemptId = nextTaskAttemptId.getAndIncrement + val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + TaskContext.setTaskContext(mockTaskContext) + body + } finally { + TaskContext.unset() + } } } thread.start() thread } - test("single thread requesting memory") { + test("single task requesting memory") { val manager = new ShuffleMemoryManager(1000L) assert(manager.tryToAcquire(100L) === 100L) @@ -50,7 +63,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(manager.tryToAcquire(300L) === 300L) assert(manager.tryToAcquire(300L) === 200L) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() assert(manager.tryToAcquire(1000L) === 1000L) assert(manager.tryToAcquire(100L) === 0L) } @@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } - test("threads cannot grow past 1 / N") { - // Two threads request 250 bytes first, wait for each other to get it, and then request + test("tasks cannot grow past 1 / N") { + // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request val manager = new ShuffleMemoryManager(1000L) @@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(state.t2Result2 === 250L) } - test("threads can block to get at least 1 / 2N memory") { + test("tasks can block to get at least 1 / 2N memory") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. @@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("releaseMemoryForThisThread") { + test("releaseMemoryForThisTask") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
@@ -251,9 +264,9 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise + // sure the other task blocks for some time otherwise Thread.sleep(300) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() } val t2 = startThread("t2") { @@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { t2.join() } - // Both threads should've been able to acquire their memory; the second one will have waited + // Both tasks should've been able to acquire their memory; the second one will have waited // until the first one acquired 1000 bytes and then released all of it state.synchronized { assert(state.t1Result === 1000L, "t1 could not allocate memory") @@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("threads should not be granted a negative size") { + test("tasks should not be granted a negative size") { val manager = new ShuffleMemoryManager(1000L) manager.tryToAcquire(700L) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bcee901f5dd5f..f480fd107a0c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Reserve - memoryStore.reserveUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 100) - memoryStore.reserveUnrollMemoryForThisThread(200) - assert(memoryStore.currentUnrollMemoryForThisThread === 300) - memoryStore.reserveUnrollMemoryForThisThread(500) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) - memoryStore.reserveUnrollMemoryForThisThread(1000000) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + memoryStore.reserveUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + memoryStore.reserveUnrollMemoryForThisTask(200) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + memoryStore.reserveUnrollMemoryForThisTask(500) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 700) - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 600) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisThread(4400) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) - memoryStore.reserveUnrollMemoryForThisThread(20000) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted + 
memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisThread(1000) - assert(memoryStore.currentUnrollMemoryForThisThread === 4000) - memoryStore.releaseUnrollMemoryForThisThread() // release all - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) } /** @@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) val memoryStore = store.memoryStore val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) - memoryStore.releasePendingUnrollMemoryForThisThread() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisThread() + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. 
@@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) droppedBlocks.clear() @@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with plenty of space. This should succeed and cache both blocks. val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) @@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result2.size > 0) assert(result1.data.isLeft) // unroll did not drop this block to disk assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Re-put these two blocks so block manager knows about them too. Otherwise, block manager // would not know how to drop them from memory later. @@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") store.putIterator("b3", smallIterator, memOnly) @@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } /** @@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) store.putIterator("b1", smallIterator, memAndDisk) store.putIterator("b2", smallIterator, memAndDisk) @@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b3")) memoryStore.remove("b3") store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk // directly in addition to kicking out b2 in the process. 
Memory store should contain only @@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(diskStore.contains("b2")) assert(!diskStore.contains("b3")) assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } test("multiple unrolls by the same thread") { @@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // All unroll memory used is released because unrollSafely returned an array memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll memory is not released because unrollSafely returned an iterator // that still depends on the underlying vector used in the process memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... 
but only to a certain extent (until we run out of free space to grant new unroll memory) memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) From 6309b93467b06f27cd76d4662b51b47de100c677 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 28 Jul 2015 22:38:28 -0700 Subject: [PATCH 130/219] [SPARK-9398] [SQL] Datetime cleanup JIRA: https://issues.apache.org/jira/browse/SPARK-9398 Author: Yijie Shen Closes #7725 from yjshen/date_null_check and squashes the following commits: b4eade1 [Yijie Shen] inline daysToMonthEnd d09acc1 [Yijie Shen] implement getLastDayOfMonth to avoid repeated evaluation d857ec3 [Yijie Shen] add null check in DateExpressionSuite --- .../expressions/datetimeFunctions.scala | 45 ++++++------------- .../sql/catalyst/util/DateTimeUtils.scala | 43 +++++++++++++----- .../expressions/DateExpressionsSuite.scala | 2 + 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index c37afc13f2d17..efecb771f2f5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -74,9 +74,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getHours($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } } @@ -92,9 +90,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getMinutes($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } } @@ -110,9 +106,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getSeconds($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } } @@ -128,9 +122,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getDayInYear($c)""" - ) + defineCodeGen(ctx, ev, c => 
s"$dtu.getDayInYear($c)") } } @@ -147,9 +139,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, c => - s"""$dtu.getYear($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } } @@ -165,9 +155,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getQuarter($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } } @@ -183,9 +171,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getMonth($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } } @@ -201,9 +187,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (c) => - s"""$dtu.getDayOfMonth($c)""" - ) + defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } } @@ -226,7 +210,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, (time) => { + nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") ctx.addMutableState(cal, c, @@ -250,8 +234,6 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType) - override def prettyName: String = "date_format" - override protected def nullSafeEval(timestamp: Any, format: Any): Any = { val sdf = new SimpleDateFormat(format.toString) UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000))) @@ -264,6 +246,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx .format(new java.sql.Date($timestamp / 1000)))""" }) } + + override def prettyName: String = "date_format" } /** @@ -277,15 +261,12 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC override def dataType: DataType = DateType override def nullSafeEval(date: Any): Any = { - val days = date.asInstanceOf[Int] - DateTimeUtils.getLastDayOfMonth(days) + DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - defineCodeGen(ctx, ev, (sd) => { - s"$dtu.getLastDayOfMonth($sd)" - }) + defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") } override def prettyName: String = "last_day" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 8b0b80c26db17..93966a503c27c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -600,23 +600,44 @@ object DateTimeUtils { startDate + 1 + ((dayOfWeek - 1 - startDate) % 7 + 7) % 7 } - /** - * number of days in a non-leap year. - */ - private[this] val daysInNormalYear = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) - /** * Returns last day of the month for the given date. The date is expressed in days * since 1.1.1970. */ def getLastDayOfMonth(date: Int): Int = { - val dayOfMonth = getDayOfMonth(date) - val month = getMonth(date) - if (month == 2 && isLeapYear(getYear(date))) { - date + daysInNormalYear(month - 1) + 1 - dayOfMonth + var (year, dayInYear) = getYearAndDayInYear(date) + if (isLeapYear(year)) { + if (dayInYear > 31 && dayInYear <= 60) { + return date + (60 - dayInYear) + } else if (dayInYear > 60) { + dayInYear = dayInYear - 1 + } + } + val lastDayOfMonthInYear = if (dayInYear <= 31) { + 31 + } else if (dayInYear <= 59) { + 59 + } else if (dayInYear <= 90) { + 90 + } else if (dayInYear <= 120) { + 120 + } else if (dayInYear <= 151) { + 151 + } else if (dayInYear <= 181) { + 181 + } else if (dayInYear <= 212) { + 212 + } else if (dayInYear <= 243) { + 243 + } else if (dayInYear <= 273) { + 273 + } else if (dayInYear <= 304) { + 304 + } else if (dayInYear <= 334) { + 334 } else { - date + daysInNormalYear(month - 1) - dayOfMonth + 365 } + date + (lastDayOfMonthInYear - dayInYear) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 30c5769424bd7..aca8d6eb3500c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -106,6 +106,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + checkEvaluation(DayOfYear(Literal.create(null, DateType)), null) } test("Year") { @@ -274,6 +275,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(LastDay(Literal(Date.valueOf("2015-12-05"))), Date.valueOf("2015-12-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-01-06"))), Date.valueOf("2016-01-31")) checkEvaluation(LastDay(Literal(Date.valueOf("2016-02-07"))), Date.valueOf("2016-02-29")) + checkEvaluation(LastDay(Literal.create(null, DateType)), null) } test("next_day") { From 15667a0afa5fb17f4cc6fbf32b2ddb573630f20a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 28 Jul 2015 22:51:08 -0700 Subject: [PATCH 131/219] [SPARK-9281] [SQL] use decimal or double when parsing SQL Right now, we use double to parse all the float numbers in SQL. When such a value is used in an expression together with DecimalType, it turns the decimal into a double as well, and using double also loses some precision. This PR changes the parser to produce a decimal or a double, based on whether the literal uses scientific notation or not, see https://msdn.microsoft.com/en-us/library/ms179899.aspx This is a breaking change; should we document it somewhere?
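As a minimal sketch of the new rule (standalone here for illustration; it mirrors the toDecimalOrDouble helper added below, and parseFloatLiteral is a hypothetical name used only in this example), a literal in scientific notation becomes a double while any other float literal stays an exact decimal:

  def parseFloatLiteral(value: String): Any = {
    val decimal = BigDecimal(value)
    if (value.contains('E') || value.contains('e')) {
      decimal.doubleValue()  // e.g. "1.5e3" is parsed as the Double 1500.0
    } else {
      decimal.underlying()   // e.g. "0.3" is kept as the exact java.math.BigDecimal 0.3
    }
  }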
Author: Davies Liu Closes #7642 from davies/parse_decimal and squashes the following commits: 1f576d9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal 5e142b6 [Davies Liu] fix scala style eca99de [Davies Liu] fix tests 2afe702 [Davies Liu] Merge branch 'master' of github.com:apache/spark into parse_decimal f4a320b [Davies Liu] Update SqlParser.scala 1c48e34 [Davies Liu] use decimal or double when parsing SQL --- .../apache/spark/sql/catalyst/SqlParser.scala | 14 +++++- .../catalyst/analysis/HiveTypeCoercion.scala | 50 ++++++++++++------- .../sql/catalyst/analysis/AnalysisSuite.scala | 4 +- .../spark/sql/MathExpressionsSuite.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 14 +++--- .../org/apache/spark/sql/json/JsonSuite.scala | 14 +++--- 6 files changed, 62 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b423f0fa04f69..e5f115f74bf3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -332,8 +332,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } | sign.? ~ unsignedFloat ^^ { - // TODO(davies): some precisions may loss, we should create decimal literal - case s ~ f => Literal(BigDecimal(s.getOrElse("") + f).doubleValue()) + case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } ) @@ -420,6 +419,17 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { } } + private def toDecimalOrDouble(value: String): Any = { + val decimal = BigDecimal(value) + // follow the behavior in MS SQL Server + // https://msdn.microsoft.com/en-us/library/ms179899.aspx + if (value.contains('E') || value.contains('e')) { + decimal.doubleValue() + } else { + decimal.underlying() + } + } + protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) | ident <~ "." ~ "*" ^^ { case tableName => UnresolvedStar(Option(tableName)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e0527503442f0..ecc48986e35d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -109,13 +109,35 @@ object HiveTypeCoercion { * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. 
*/ - private def findTightestCommonType(types: Seq[DataType]) = { + private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { case None => None case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } + private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (t: FractionalType, d: DecimalType) => + Some(DoubleType) + case (d: DecimalType, t: FractionalType) => + Some(DoubleType) + case _ => + findTightestCommonTypeToString(t1, t2) + } + + private def findWiderCommonType(types: Seq[DataType]) = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderTypeForTwo(d, c) + case None => None + }) + } + /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to * instances higher in the query tree. @@ -182,20 +204,7 @@ object HiveTypeCoercion { val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => - (lhs.dataType, rhs.dataType) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (t: FractionalType, d: DecimalType) => - Some(DoubleType) - case (d: DecimalType, t: FractionalType) => - Some(DoubleType) - case _ => - findTightestCommonTypeToString(lhs.dataType, rhs.dataType) - } + findWiderTypeForTwo(lhs.dataType, rhs.dataType) case other => None } @@ -236,8 +245,13 @@ object HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), r) => - a.makeCopy(Array(Cast(left, DoubleType), r)) + case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => + a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) + case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => + a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) + + case a @ BinaryArithmetic(left @ StringType(), right) => + a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => a.makeCopy(Array(left, Cast(right, DoubleType))) @@ -543,7 +557,7 @@ object HiveTypeCoercion { // compatible with every child column. 
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - findTightestCommonTypeAndPromoteToString(types) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 4589facb49b76..221b4e92f086c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -145,11 +145,11 @@ class AnalysisSuite extends AnalysisTest { 'e / 'e as 'div5)) val pl = plan.asInstanceOf[Project].projectList - // StringType will be promoted into Double assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DoubleType) + // StringType will be promoted into Decimal(38, 18) + assert(pl(3).dataType == DecimalType(38, 29)) assert(pl(4).dataType == DoubleType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 21256704a5b16..8cf2ef5957d8d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -216,7 +216,8 @@ class MathExpressionsSuite extends QueryTest { checkAnswer( ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " + s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"), - Seq(Row(0.0, 0.0, 0.0, 3.0, 3.1, 3.14, 3.142)) + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 42724ed766af5..d13dde1cdc8b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -368,7 +368,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(1)) checkAnswer( sql("SELECT COALESCE(null, 1, 1.5)"), - Row(1.toDouble)) + Row(BigDecimal(1))) checkAnswer( sql("SELECT COALESCE(null, null, null)"), Row(null)) @@ -1234,19 +1234,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Floating point number format") { checkAnswer( - sql("SELECT 0.3"), Row(0.3) + sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying()) ) checkAnswer( - sql("SELECT -0.8"), Row(-0.8) + sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying()) ) checkAnswer( - sql("SELECT .5"), Row(0.5) + sql("SELECT .5"), Row(BigDecimal(0.5)) ) checkAnswer( - sql("SELECT -.18"), Row(-0.18) + sql("SELECT -.18"), Row(BigDecimal(-0.18)) ) } @@ -1279,11 +1279,11 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { ) checkAnswer( - sql("SELECT -5.2"), Row(-5.2) + sql("SELECT -5.2"), Row(BigDecimal(-5.2)) ) checkAnswer( - sql("SELECT +6.8"), Row(6.8) + sql("SELECT +6.8"), Row(BigDecimal(6.8)) ) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 3ac312d6f4c50..f19f22fca7d54 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -422,14 +422,14 @@ class JsonSuite extends QueryTest with TestJsonData { Row(-89) :: Row(21474836370L) :: Row(21474836470L) :: Nil ) - // Widening to DoubleType + // Widening to DecimalType checkAnswer( sql("select num_num_2 + 1.3 from jsonTable where num_num_2 > 1.1"), - Row(21474836472.2) :: - Row(92233720368547758071.3) :: Nil + Row(BigDecimal("21474836472.2")) :: + Row(BigDecimal("92233720368547758071.3")) :: Nil ) - // Widening to DoubleType + // Widening to Double checkAnswer( sql("select num_num_3 + 1.2 from jsonTable where num_num_3 > 1.1"), Row(101.2) :: Row(21474836471.2) :: Nil @@ -438,13 +438,13 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(92233720368547758071.2) + Row(BigDecimal("92233720368547758071.2")) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) + Row(new java.math.BigDecimal("92233720368547758071.2")) ) // String and Boolean conflict: resolve the type as string. @@ -503,7 +503,7 @@ class JsonSuite extends QueryTest with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 13"), - Row(14.3) :: Row(92233720368547758071.2) :: Nil + Row(BigDecimal("14.3")) :: Row(BigDecimal("92233720368547758071.2")) :: Nil ) } From 708794e8aae2c66bd291bab4f12117c33b57840c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 29 Jul 2015 00:08:45 -0700 Subject: [PATCH 132/219] [SPARK-9251][SQL] do not order by expressions which still need evaluation Per an offline discussion with rxin: it's weird to be computing values while sorting, so we should only order by bound references during execution.
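As a concrete illustration (using a hypothetical SQLContext ctx and table t, neither of which comes from the patch), a query that sorts by a computed expression will now be analyzed roughly as follows:

  ctx.sql("SELECT a, b FROM t ORDER BY b + 1")
  // Analyzed plan, schematically:
  //   Project [a, b]                                 <- outer project drops the helper column again
  //     Sort [_sortCondition ASC]                    <- the sort sees only an AttributeReference
  //       Project [a, b, (b + 1) AS _sortCondition]  <- inner project evaluates b + 1 once, before the sort
  //         Scan t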
Author: Wenchen Fan Closes #7593 from cloud-fan/sort and squashes the following commits: 7b1bef7 [Wenchen Fan] add test daf206d [Wenchen Fan] add more comments 289bee0 [Wenchen Fan] do not order by expressions which still need evaluation --- .../sql/catalyst/analysis/Analyzer.scala | 58 +++++++++++++++++++ .../sql/catalyst/expressions/random.scala | 4 +- .../plans/logical/basicOperators.scala | 13 +++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 36 ++++++++++-- .../scala/org/apache/spark/sql/TestData.scala | 2 - 5 files changed, 101 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a6ea0cc0a83a8..265f3d1e41765 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -79,6 +79,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: + RemoveEvaluationFromSort :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -947,6 +948,63 @@ class Analyzer( Project(p.output, newPlan.withNewChildren(newChild :: Nil)) } } + + /** + * Removes from Sort all ordering expressions that still need evaluation, using an inner project to + * materialize them and a final outer project to project them away again, keeping the result the same. + * This way we can make sure we only sort by [[AttributeReference]]s. + * + * As an example, + * {{{ + * Sort('a, 'b + 1, + * Relation('a, 'b)) + * }}} + * will be turned into: + * {{{ + * Project('a, 'b, + * Sort('a, '_sortCondition, + * Project('a, 'b, ('b + 1).as("_sortCondition"), + * Relation('a, 'b)))) + * }}} + */ + object RemoveEvaluationFromSort extends Rule[LogicalPlan] { + private def hasAlias(expr: Expression) = { + expr.find { + case a: Alias => true + case _ => false + }.isDefined + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // The ordering expressions have no effect on the output schema of `Sort`, + // so `Alias`s in ordering expressions are unnecessary and we should remove them. + case s @ Sort(ordering, _, _) if ordering.exists(hasAlias) => + val newOrdering = ordering.map(_.transformUp { + case Alias(child, _) => child + }.asInstanceOf[SortOrder]) + s.copy(order = newOrdering) + + case s @ Sort(ordering, global, child) + if s.expressions.forall(_.resolved) && s.childrenResolved && !s.hasNoEvaluation => + + val (ref, needEval) = ordering.partition(_.child.isInstanceOf[AttributeReference]) + + val namedExpr = needEval.map(_.child match { + case n: NamedExpression => n + case e => Alias(e, "_sortCondition")() + }) + + val newOrdering = ref ++ needEval.zip(namedExpr).map { case (order, ne) => + order.copy(child = ne.toAttribute) + } + + // Add the ordering expressions that still need evaluation to the inner project, then project + // them away after the sort.
+ Project(child.output, + Sort(newOrdering, global, + Project(child.output ++ namedExpr, child))) + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala index 8f30519697a37..62d3d204ca872 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala @@ -66,7 +66,7 @@ case class Rand(seed: Long) extends RDG { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextDouble(); @@ -89,7 +89,7 @@ case class Randn(seed: Long) extends RDG { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, - s"$rngTerm = new $className($seed + org.apache.spark.TaskContext.getPartitionId());") + s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") ev.isNull = "false" s""" final ${ctx.javaType(dataType)} ${ev.primitive} = $rngTerm.nextGaussian(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index af68358daf5f1..ad5af19578f33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -33,7 +33,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend }.nonEmpty ) - !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions + expressions.forall(_.resolved) && childrenResolved && !hasSpecialExpressions } } @@ -67,7 +67,7 @@ case class Generate( generator.resolved && childrenResolved && generator.elementTypes.length == generatorOutput.length && - !generatorOutput.exists(!_.resolved) + generatorOutput.forall(_.resolved) } // we don't want the gOutput to be taken as part of the expressions @@ -187,7 +187,7 @@ case class WithWindowDefinition( } /** - * @param order The ordering expressions + * @param order The ordering expressions, should all be [[AttributeReference]] * @param global True means global sorting apply for entire data set, * False means sorting only apply within the partition. 
* @param child Child logical plan @@ -197,6 +197,11 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + + def hasNoEvaluation: Boolean = order.forall(_.child.isInstanceOf[AttributeReference]) + + override lazy val resolved: Boolean = + expressions.forall(_.resolved) && childrenResolved && hasNoEvaluation } case class Aggregate( @@ -211,7 +216,7 @@ case class Aggregate( }.nonEmpty ) - !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions + expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions } override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 221b4e92f086c..a86cefe941e8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -165,11 +165,39 @@ class AnalysisSuite extends AnalysisTest { test("pull out nondeterministic expressions from Sort") { val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation) - val projected = Alias(Rand(33), "_nondeterministic")() + val analyzed = caseSensitiveAnalyzer.execute(plan) + analyzed.transform { + case s: Sort if s.expressions.exists(!_.deterministic) => + fail("nondeterministic expressions are not allowed in Sort") + } + } + + test("remove still-need-evaluate ordering expressions from sort") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + + def makeOrder(e: Expression): SortOrder = SortOrder(e, Ascending) + + val noEvalOrdering = makeOrder(a) + val noEvalOrderingWithAlias = makeOrder(Alias(Alias(b, "name1")(), "name2")()) + + val needEvalExpr = Coalesce(Seq(a, Literal("1"))) + val needEvalExpr2 = Coalesce(Seq(a, b)) + val needEvalOrdering = makeOrder(needEvalExpr) + val needEvalOrdering2 = makeOrder(needEvalExpr2) + + val plan = Sort( + Seq(noEvalOrdering, noEvalOrderingWithAlias, needEvalOrdering, needEvalOrdering2), + false, testRelation2) + + val evaluatedOrdering = makeOrder(AttributeReference("_sortCondition", StringType)()) + val materializedExprs = Seq(needEvalExpr, needEvalExpr2).map(e => Alias(e, "_sortCondition")()) + val expected = - Project(testRelation.output, - Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false, - Project(testRelation.output :+ projected, testRelation))) + Project(testRelation2.output, + Sort(Seq(makeOrder(a), makeOrder(b), evaluatedOrdering, evaluatedOrdering), false, + Project(testRelation2.output ++ materializedExprs, testRelation2))) + checkAnalysis(plan, expected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 207d7a352c7b3..e340f54850bcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.sql.Timestamp - import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ From 97906944e133dec13068f16520b6abbcdc79e84f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 09:36:22 -0700 Subject: [PATCH 133/219] [SPARK-9127][SQL] Rand/Randn codegen fails with long seed. 
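The failure mode, sketched for illustration (the template strings below are a simplification of the codegen in random.scala, and partitionId stands in for the org.apache.spark.TaskContext.getPartitionId() call): interpolating a Long seed into generated Java source without an L suffix yields an invalid Java integer literal once the seed exceeds Int range.

  val seed = 5419823303878592871L
  s"new XORShiftRandom($seed + partitionId)"    // emits ...5419823303878592871 + ... -> javac: integer number too large
  s"new XORShiftRandom(${seed}L + partitionId)" // emits ...5419823303878592871L + ... -> a valid long literal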
Author: Reynold Xin Closes #7747 from rxin/SPARK-9127 and squashes the following commits: e851418 [Reynold Xin] [SPARK-9127][SQL] Rand/Randn codegen fails with long seed. --- .../spark/sql/catalyst/expressions/RandomSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 698c81ba24482..5db992654811a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.DoubleType class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -30,4 +28,9 @@ class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) } + + test("SPARK-9127 codegen with long seed") { + checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + } } From 069a4c414db4612d7bdb6f5615c1ba36998e5a49 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Wed, 29 Jul 2015 14:02:32 -0500 Subject: [PATCH 134/219] [SPARK-746] [CORE] Added Avro Serialization to Kryo Added a custom Kryo serializer for generic Avro records to reduce the network IO involved during a shuffle. This compresses the schema and allows for users to register their schemas ahead of time to further reduce traffic. Currently Kryo tries to use its default serializer for generic Records, which will include a lot of unneeded data in each record. Author: Joseph Batchik Author: Joseph Batchik Closes #7004 from JDrit/Avro_serialization and squashes the following commits: 8158d51 [Joseph Batchik] updated per feedback c0cf329 [Joseph Batchik] implemented @squito suggestion for SparkEnv dd71efe [Joseph Batchik] fixed bug with serializing 1183a48 [Joseph Batchik] updated codec settings fa9298b [Joseph Batchik] forgot a couple of fixes c5fe794 [Joseph Batchik] implemented @squito suggestion 0f5471a [Joseph Batchik] implemented @squito suggestion to use a codec that is already in spark 6d1925c [Joseph Batchik] fixed to changes suggested by @squito d421bf5 [Joseph Batchik] updated pom to removed versions ab46d10 [Joseph Batchik] Changed Avro dependency to be similar to parent f4ae251 [Joseph Batchik] fixed serialization error in that SparkConf cannot be serialized 2b545cc [Joseph Batchik] started working on fixes for pr 97fba62 [Joseph Batchik] Added a custom Kryo serializer for generic Avro records to reduce the network IO involved during a shuffle. This compresses the schema and allows for users to register their schemas ahead of time to further reduce traffic. 
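Typical usage of the new registration API, as a sketch based on the registerAvroSchemas method added below (the user record schema is invented for this example):

  import org.apache.avro.SchemaBuilder
  import org.apache.spark.SparkConf

  val userSchema = SchemaBuilder.record("user").fields()
    .requiredString("name")
    .endRecord()

  val conf = new SparkConf()
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .registerAvroSchemas(userSchema) // registered records ship a 64-bit schema fingerprint instead of the full schema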
--- core/pom.xml | 5 + .../scala/org/apache/spark/SparkConf.scala | 23 ++- .../serializer/GenericAvroSerializer.scala | 150 ++++++++++++++++++ .../spark/serializer/KryoSerializer.scala | 6 + .../GenericAvroSerializerSuite.scala | 84 ++++++++++ 5 files changed, 267 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala create mode 100644 core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 95f36eb348698..6fa87ec6a24af 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,11 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro-mapred + ${avro.mapred.classifier} + com.google.guava guava diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 6cf36fbbd6254..4161792976c7b 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,11 +18,12 @@ package org.apache.spark import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet +import org.apache.avro.{SchemaNormalization, Schema} + import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -161,6 +162,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + private final val avroNamespace = "avro.schema." + + /** + * Use Kryo serialization and register the given set of Avro schemas so that the generic + * record serializer can decrease network IO + */ + def registerAvroSchemas(schemas: Schema*): SparkConf = { + for (schema <- schemas) { + set(avroNamespace + SchemaNormalization.parsingFingerprint64(schema), schema.toString) + } + this + } + + /** Gets all the avro schemas in the configuration used in the generic Avro record serializer */ + def getAvroSchema: Map[Long, String] = { + getAll.filter { case (k, v) => k.startsWith(avroNamespace) } + .map { case (k, v) => (k.substring(avroNamespace.length).toLong, v) } + .toMap + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala new file mode 100644 index 0000000000000..62f8aae7f2126 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import scala.collection.mutable + +import com.esotericsoftware.kryo.{Kryo, Serializer => KSerializer} +import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} +import org.apache.avro.{Schema, SchemaNormalization} +import org.apache.avro.generic.{GenericData, GenericRecord} +import org.apache.avro.io._ +import org.apache.commons.io.IOUtils + +import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.io.CompressionCodec + +/** + * Custom serializer used for generic Avro records. If the user registers the schemas + * ahead of time, then the schema's fingerprint will be sent with each message instead of the actual + * schema, so as to reduce network IO. + * Actions like parsing or compressing schemas are computationally expensive so the serializer + * caches all previously seen values so as to reduce the amount of work needed. + * @param schemas a map where the keys are unique IDs for Avro schemas and the values are the + * string representation of the Avro schema, used to decrease the amount of data + * that needs to be serialized. + */ +private[serializer] class GenericAvroSerializer(schemas: Map[Long, String]) + extends KSerializer[GenericRecord] { + + /** Used to reduce the amount of effort to compress the schema */ + private val compressCache = new mutable.HashMap[Schema, Array[Byte]]() + private val decompressCache = new mutable.HashMap[ByteBuffer, Schema]() + + /** Reuses the same datum reader/writer since the same schema will be used many times */ + private val writerCache = new mutable.HashMap[Schema, DatumWriter[_]]() + private val readerCache = new mutable.HashMap[Schema, DatumReader[_]]() + + /** Fingerprinting is very expensive so this alleviates most of the work */ + private val fingerprintCache = new mutable.HashMap[Schema, Long]() + private val schemaCache = new mutable.HashMap[Long, Schema]() + + // GenericAvroSerializer can't take a SparkConf in the constructor b/c then it would become + // a member of KryoSerializer, which would make KryoSerializer not Serializable. We make + // the codec lazy here just b/c in some unit tests, we use a KryoSerializer w/out having + // the SparkEnv set (note those tests would fail if they tried to serialize avro data). + private lazy val codec = CompressionCodec.createCodec(SparkEnv.get.conf) + + /** + * Used to compress Schemas when they are being sent over the wire. + * The compression results are memoized to reduce the compression time since the + * same schema is compressed many times over + */ + def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, { + val bos = new ByteArrayOutputStream() + val out = codec.compressedOutputStream(bos) + out.write(schema.toString.getBytes("UTF-8")) + out.close() + bos.toByteArray + }) + + /** + * Decompresses the schema into the actual in-memory object. Keeps an internal cache of already + * seen values so as to limit the number of times that decompression has to be done. + */ + def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, { + val bis = new ByteArrayInputStream(schemaBytes.array()) + val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis)) + new Schema.Parser().parse(new String(bytes, "UTF-8")) + }) + + /** + * Serializes a record to the given output stream.
It caches a lot of the internal data so as + * not to redo work + */ + def serializeDatum[R <: GenericRecord](datum: R, output: KryoOutput): Unit = { + val encoder = EncoderFactory.get.binaryEncoder(output, null) + val schema = datum.getSchema + val fingerprint = fingerprintCache.getOrElseUpdate(schema, { + SchemaNormalization.parsingFingerprint64(schema) + }) + schemas.get(fingerprint) match { + case Some(_) => + output.writeBoolean(true) + output.writeLong(fingerprint) + case None => + output.writeBoolean(false) + val compressedSchema = compress(schema) + output.writeInt(compressedSchema.length) + output.writeBytes(compressedSchema) + } + + writerCache.getOrElseUpdate(schema, GenericData.get.createDatumWriter(schema)) + .asInstanceOf[DatumWriter[R]] + .write(datum, encoder) + encoder.flush() + } + + /** + * Deserializes generic records into their in-memory form. There is internal + * state to keep a cache of already seen schemas and datum readers. + */ + def deserializeDatum(input: KryoInput): GenericRecord = { + val schema = { + if (input.readBoolean()) { + val fingerprint = input.readLong() + schemaCache.getOrElseUpdate(fingerprint, { + schemas.get(fingerprint) match { + case Some(s) => new Schema.Parser().parse(s) + case None => + throw new SparkException( + "Error attempting to read avro data -- encountered an unknown " + + s"fingerprint: $fingerprint, not sure what schema to use. This could happen " + + "if you registered additional schemas after starting your spark context.") + } + }) + } else { + val length = input.readInt() + decompress(ByteBuffer.wrap(input.readBytes(length))) + } + } + val decoder = DecoderFactory.get.directBinaryDecoder(input, null) + readerCache.getOrElseUpdate(schema, GenericData.get.createDatumReader(schema)) + .asInstanceOf[DatumReader[GenericRecord]] + .read(null, decoder) + } + + override def write(kryo: Kryo, output: KryoOutput, datum: GenericRecord): Unit = + serializeDatum(datum, output) + + override def read(kryo: Kryo, input: KryoInput, datumClass: Class[GenericRecord]): GenericRecord = + deserializeDatum(input) +} diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 7cb6e080533ad..0ff7562e912ca 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,6 +27,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} +import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.{ArrayContainer, BitmapContainer, RoaringArray, RoaringBitmap} import org.apache.spark._ @@ -73,6 +74,8 @@ class KryoSerializer(conf: SparkConf) .split(',') .filter(!_.isEmpty) + private val avroSchemas = conf.getAvroSchema + def newKryoOutput(): KryoOutput = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) def newKryo(): Kryo = { @@ -101,6 +104,9 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) + kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) + kryo.register(classOf[GenericData.Record], new GenericAvroSerializer(avroSchemas)) + try { //
scalastyle:off classforname // Use the default classloader when calling the user registrator. diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala new file mode 100644 index 0000000000000..bc9f3708ed69d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.io.{Output, Input} +import org.apache.avro.{SchemaBuilder, Schema} +import org.apache.avro.generic.GenericData.Record + +import org.apache.spark.{SparkFunSuite, SharedSparkContext} + +class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + + val schema : Schema = SchemaBuilder + .record("testRecord").fields() + .requiredString("data") + .endRecord() + val record = new Record(schema) + record.put("data", "test data") + + test("schema compression and decompression") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema)))) + } + + test("record serialization and deserialization") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + + val outputStream = new ByteArrayOutputStream() + val output = new Output(outputStream) + genericSer.serializeDatum(record, output) + output.flush() + output.close() + + val input = new Input(new ByteArrayInputStream(outputStream.toByteArray)) + assert(genericSer.deserializeDatum(input) === record) + } + + test("uses schema fingerprint to decrease message size") { + val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema) + + val output = new Output(new ByteArrayOutputStream()) + + val beginningNormalPosition = output.total() + genericSerFull.serializeDatum(record, output) + output.flush() + val normalLength = output.total - beginningNormalPosition + + conf.registerAvroSchemas(schema) + val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema) + val beginningFingerprintPosition = output.total() + genericSerFinger.serializeDatum(record, output) + val fingerprintLength = output.total - beginningFingerprintPosition + + assert(fingerprintLength < normalLength) + } + + test("caches previously seen schemas") { + val genericSer = new GenericAvroSerializer(conf.getAvroSchema) + val compressedSchema = genericSer.compress(schema) + val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema)) + + 
assert(compressedSchema.eq(genericSer.compress(schema))) + assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema)))) + } +} From 819be46e5a73f2d19230354ebba30c58538590f5 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Wed, 29 Jul 2015 13:47:37 -0700 Subject: [PATCH 135/219] [SPARK-8977] [STREAMING] Defines the RateEstimator interface, and implements the RateController MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on #7471. - [x] add a test that exercises the publish path from driver to receiver - [ ] remove Serializable from `RateController` and `RateEstimator` Author: Iulian Dragos Author: François Garillot Closes #7600 from dragos/topic/streaming-bp/rate-controller and squashes the following commits: f168c94 [Iulian Dragos] Latest review round. 5125e60 [Iulian Dragos] Fix style. a2eb3b9 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into topic/streaming-bp/rate-controller 475e346 [Iulian Dragos] Latest round of reviews. e9fb45e [Iulian Dragos] - Add a test for checkpointing - fixed serialization for RateController.executionContext 715437a [Iulian Dragos] Review comments and added a `reset` call in ReceiverTrackerTest. e57c66b [Iulian Dragos] Added a couple of tests for the full scenario from driver to receivers, with several rate updates. b425d32 [Iulian Dragos] Removed DeveloperAPI, removed rateEstimator field, removed Noop rate estimator, changed logic for initialising rate estimator. 238cfc6 [Iulian Dragos] Merge remote-tracking branch 'upstream/master' into topic/streaming-bp/rate-controller 34a389d [Iulian Dragos] Various style changes and a first test for the rate controller. d32ca36 [François Garillot] [SPARK-8977][Streaming] Defines the RateEstimator interface, and implements the ReceiverRateController 8941cf9 [Iulian Dragos] Renames and other nitpicks. 162d9e5 [Iulian Dragos] Use Reflection for accessing truly private `executor` method and use the listener bus to know when receivers have registered (`onStart` is called before receivers have registered, leading to flaky behavior). 210f495 [Iulian Dragos] Revert "Added a few tests that measure the receiver’s rate." 0c51959 [Iulian Dragos] Added a few tests that measure the receiver’s rate. 261a051 [Iulian Dragos] - removed field to hold the current rate limit in rate limiter - made rate limit a Long and default to Long.MaxValue (consequence of the above) - removed custom `waitUntil` and replaced it by `eventually` cd1397d [Iulian Dragos] Add a test for the propagation of a new rate limit from driver to receivers. 6369b30 [Iulian Dragos] Merge pull request #15 from huitseeker/SPARK-8975 d15de42 [François Garillot] [SPARK-8975][Streaming] Adds Ratelimiter unit tests w.r.t.
spark.streaming.receiver.maxRate 4721c7d [François Garillot] [SPARK-8975][Streaming] Add a mechanism to send a new rate from the driver to the block generator --- .../streaming/dstream/InputDStream.scala | 7 +- .../dstream/ReceiverInputDStream.scala | 26 ++++- .../streaming/scheduler/JobScheduler.scala | 6 + .../streaming/scheduler/RateController.scala | 90 +++++++++++++++ .../scheduler/rate/RateEstimator.scala | 59 ++++++++++ .../spark/streaming/CheckpointSuite.scala | 28 +++++ .../scheduler/RateControllerSuite.scala | 103 ++++++++++++++++++ .../ReceiverSchedulingPolicySuite.scala | 10 +- .../scheduler/ReceiverTrackerSuite.scala | 41 +++++-- 9 files changed, 355 insertions(+), 15 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index d58c99a8ff321..a6c4cd220e42f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -21,7 +21,9 @@ import scala.reflect.ClassTag import org.apache.spark.SparkContext import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} +import org.apache.spark.streaming.scheduler.RateController +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils /** @@ -47,6 +49,9 @@ abstract class InputDStream[T: ClassTag] (@transient ssc_ : StreamingContext) /** This is an unique identifier for the input stream. */ val id = ssc.getNewInputStreamId() + // Keep track of the freshest rate for this stream using the rateEstimator + protected[streaming] val rateController: Option[RateController] = None + /** A human-readable name of this InputDStream */ private[streaming] def name: String = { // e.g. 
FlumePollingDStream -> "Flume polling stream" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index a50f0efc030ce..646a8c3530a62 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,10 +21,11 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId -import org.apache.spark.streaming._ +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.StreamInputInfo +import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo} +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.streaming.util.WriteAheadLogUtils /** @@ -40,6 +41,17 @@ import org.apache.spark.streaming.util.WriteAheadLogUtils abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { + /** + * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker. + */ + override protected[streaming] val rateController: Option[RateController] = { + if (RateController.isBackPressureEnabled(ssc.conf)) { + RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } + } else { + None + } + } + /** * Gets the receiver object that will be sent to the worker nodes * to receive data. This method needs to defined by any specific implementation @@ -110,4 +122,14 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont } Some(blockRDD) } + + /** + * A RateController that sends the new rate to receivers, via the receiver tracker. + */ + private[streaming] class ReceiverRateController(id: Int, estimator: RateEstimator) + extends RateController(id, estimator) { + override def publish(rate: Long): Unit = + ssc.scheduler.receiverTracker.sendRateUpdate(id, rate) + } } + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 4af9b6d3b56ab..58bdda7794bf2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -66,6 +66,12 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { } eventLoop.start() + // attach rate controllers of input streams to receive batch completion updates + for { + inputDStream <- ssc.graph.getInputStreams + rateController <- inputDStream.rateController + } ssc.addStreamingListener(rateController) + listenerBus.start(ssc.sparkContext) receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala new file mode 100644 index 0000000000000..882ca0676b6ad --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/RateController.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler + +import java.io.ObjectInputStream +import java.util.concurrent.atomic.AtomicLong + +import scala.concurrent.{ExecutionContext, Future} + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.scheduler.rate.RateEstimator +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * A StreamingListener that receives batch completion updates, and maintains + * an estimate of the speed at which this stream should ingest messages, + * given an estimate computation from a `RateEstimator` + */ +private[streaming] abstract class RateController(val streamUID: Int, rateEstimator: RateEstimator) + extends StreamingListener with Serializable { + + init() + + protected def publish(rate: Long): Unit + + @transient + implicit private var executionContext: ExecutionContext = _ + + @transient + private var rateLimit: AtomicLong = _ + + /** + * An initialization method called both from the constructor and Serialization code. + */ + private def init() { + executionContext = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("stream-rate-update")) + rateLimit = new AtomicLong(-1L) + } + + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + init() + } + + /** + * Compute the new rate limit and publish it asynchronously. + */ + private def computeAndPublish(time: Long, elems: Long, workDelay: Long, waitDelay: Long): Unit = + Future[Unit] { + val newRate = rateEstimator.compute(time, elems, workDelay, waitDelay) + newRate.foreach { s => + rateLimit.set(s.toLong) + publish(getLatestRate()) + } + } + + def getLatestRate(): Long = rateLimit.get() + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { + val elements = batchCompleted.batchInfo.streamIdToInputInfo + + for { + processingEnd <- batchCompleted.batchInfo.processingEndTime; + workDelay <- batchCompleted.batchInfo.processingDelay; + waitDelay <- batchCompleted.batchInfo.schedulingDelay; + elems <- elements.get(streamUID).map(_.numRecords) + } computeAndPublish(processingEnd, elems, workDelay, waitDelay) + } +} + +object RateController { + def isBackPressureEnabled(conf: SparkConf): Boolean = + conf.getBoolean("spark.streaming.backpressure.enable", false) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala new file mode 100644 index 0000000000000..a08685119e5d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import org.apache.spark.SparkConf +import org.apache.spark.SparkException + +/** + * A component that estimates the rate at which an InputDStream should ingest + * elements, based on updates at every batch completion. + */ +private[streaming] trait RateEstimator extends Serializable { + + /** + * Computes the number of elements the stream attached to this `RateEstimator` + * should ingest per second, given an update on the size and completion + * times of the latest batch. + * + * @param time The timestamp of the current batch interval that just finished + * @param elements The number of elements that were processed in this batch + * @param processingDelay The time in ms it took for the job to complete + * @param schedulingDelay The time in ms that the job spent in the scheduling queue + */ + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] +} + +object RateEstimator { + + /** + * Return a new RateEstimator based on the value of `spark.streaming.backpressure.rateEstimator`. + * + * @return None if there is no configured estimator, otherwise an instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators.
+ */ + def create(conf: SparkConf): Option[RateEstimator] = + conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") + } +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index d308ac05a54fe..67c2d900940ab 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -30,8 +30,10 @@ import org.apache.hadoop.io.{IntWritable, Text} import org.apache.hadoop.mapred.TextOutputFormat import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.streaming.scheduler.{RateLimitInputDStream, ConstantEstimator, SingletonTestRateReceiver} import org.apache.spark.util.{Clock, ManualClock, Utils} /** @@ -391,6 +393,32 @@ class CheckpointSuite extends TestSuiteBase { testCheckpointedOperation(input, operation, output, 7) } + test("recovery maintains rate controller") { + ssc = new StreamingContext(conf, batchDuration) + ssc.checkpoint(checkpointDir) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + SingletonTestRateReceiver.reset() + + val output = new TestOutputStreamWithPartitions(dstream.checkpoint(batchDuration * 2)) + output.register() + runStreams(ssc, 5, 5) + + SingletonTestRateReceiver.reset() + ssc = new StreamingContext(checkpointDir) + ssc.start() + val outputNew = advanceTimeWithRealDelay(ssc, 2) + + eventually(timeout(5.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + ssc.stop() + ssc = null + } + // This tests whether file input stream remembers what files were seen before // the master failure and uses them again to process a large window operation. // It also tests whether batches, whose processing was incomplete due to the diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala new file mode 100644 index 0000000000000..921da773f6c11 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/RateControllerSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.streaming.scheduler + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.scalatest.Matchers._ +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.streaming._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator + +class RateControllerSuite extends TestSuiteBase { + + override def useManualClock: Boolean = false + + test("rate controller publishes updates") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) + dstream.register() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.publishCalls > 0) + } + } + } + + test("publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(200.0))) + } + dstream.register() + SingletonTestRateReceiver.reset() + ssc.start() + + eventually(timeout(10.seconds)) { + assert(dstream.getCurrentRateLimit === Some(200)) + } + } + } + + test("multiple publish rates reach receivers") { + val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(ssc) { ssc => + val rates = Seq(100L, 200L, 300L) + + val dstream = new RateLimitInputDStream(ssc) { + override val rateController = + Some(new ReceiverRateController(id, new ConstantEstimator(rates.map(_.toDouble): _*))) + } + SingletonTestRateReceiver.reset() + dstream.register() + + val observedRates = mutable.HashSet.empty[Long] + ssc.start() + + eventually(timeout(20.seconds)) { + dstream.getCurrentRateLimit.foreach(observedRates += _) + // Long.MaxValue (essentially, no rate limit) is the initial rate limit for any Receiver + observedRates should contain theSameElementsAs (rates :+ Long.MaxValue) + } + } + } +} + +private[streaming] class ConstantEstimator(rates: Double*) extends RateEstimator { + private var idx: Int = 0 + + private def nextRate(): Double = { + val rate = rates(idx) + idx = (idx + 1) % rates.size + rate + } + + def compute( + time: Long, + elements: Long, + processingDelay: Long, + schedulingDelay: Long): Option[Double] = Some(nextRate()) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala index 93f920fdc71f1..0418d776ecc9a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverSchedulingPolicySuite.scala @@ -64,7 +64,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: " + "schedule receivers evenly when there are more receivers than executors") { - val receivers = (0 until 6).map(new DummyReceiver(_)) + val receivers = (0 until 6).map(new RateTestReceiver(_)) val executors = (10000 until 10003).map(port => s"localhost:${port}") val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors) val numReceiversOnExecutor = mutable.HashMap[String, Int]() @@ -79,7 +79,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite { test("scheduleReceivers: " + "schedule receivers evenly when there are more executors than receivers") { - val receivers = (0 until 
3).map(new DummyReceiver(_))
+    val receivers = (0 until 3).map(new RateTestReceiver(_))
     val executors = (10000 until 10006).map(port => s"localhost:${port}")
     val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
     val numReceiversOnExecutor = mutable.HashMap[String, Int]()
@@ -94,8 +94,8 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
   }
 
   test("scheduleReceivers: schedule receivers evenly when the preferredLocations are even") {
-    val receivers = (0 until 3).map(new DummyReceiver(_)) ++
-      (3 until 6).map(new DummyReceiver(_, Some("localhost")))
+    val receivers = (0 until 3).map(new RateTestReceiver(_)) ++
+      (3 until 6).map(new RateTestReceiver(_, Some("localhost")))
     val executors = (10000 until 10003).map(port => s"localhost:${port}") ++
       (10003 until 10006).map(port => s"localhost2:${port}")
     val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, executors)
@@ -121,7 +121,7 @@ class ReceiverSchedulingPolicySuite extends SparkFunSuite {
   }
 
   test("scheduleReceivers: return empty scheduled executors if no executors") {
-    val receivers = (0 until 3).map(new DummyReceiver(_))
+    val receivers = (0 until 3).map(new RateTestReceiver(_))
     val scheduledExecutors = receiverSchedulingPolicy.scheduleReceivers(receivers, Seq.empty)
     scheduledExecutors.foreach { case (receiverId, executors) =>
       assert(executors.isEmpty)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
index b039233f36316..aff8b53f752fa 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala
@@ -43,6 +43,7 @@ class ReceiverTrackerSuite extends TestSuiteBase {
     ssc.addStreamingListener(ReceiverStartedWaiter)
     ssc.scheduler.listenerBus.start(ssc.sc)
+    SingletonTestRateReceiver.reset()
 
     val newRateLimit = 100L
     val inputDStream = new RateLimitInputDStream(ssc)
@@ -62,36 +63,62 @@ class ReceiverTrackerSuite extends TestSuiteBase {
   }
 }
 
-/** An input DStream with a hard-coded receiver that gives access to internals for testing. */
-private class RateLimitInputDStream(@transient ssc_ : StreamingContext)
+/**
+ * An input DStream with a hard-coded receiver that gives access to internals for testing.
+ *
+ * @note Make sure to call {{{SingletonTestRateReceiver.reset()}}} before using this in a test,
+ *       or otherwise you may get {{{NotSerializableException}}} when trying to serialize
+ *       the receiver.
+ * @see [[SingletonTestRateReceiver]].
+ */
+private[streaming] class RateLimitInputDStream(@transient ssc_ : StreamingContext)
   extends ReceiverInputDStream[Int](ssc_) {
 
-  override def getReceiver(): DummyReceiver = SingletonDummyReceiver
+  override def getReceiver(): RateTestReceiver = SingletonTestRateReceiver
 
   def getCurrentRateLimit: Option[Long] = {
     invokeExecutorMethod.getCurrentRateLimit
   }
 
+  @volatile
+  var publishCalls = 0
+
+  override val rateController: Option[RateController] = {
+    Some(new RateController(id, new ConstantEstimator(100.0)) {
+      override def publish(rate: Long): Unit = {
+        publishCalls += 1
+      }
+    })
+  }
+
   private def invokeExecutorMethod: ReceiverSupervisor = {
     val c = classOf[Receiver[_]]
     val ex = c.getDeclaredMethod("executor")
     ex.setAccessible(true)
-    ex.invoke(SingletonDummyReceiver).asInstanceOf[ReceiverSupervisor]
+    ex.invoke(SingletonTestRateReceiver).asInstanceOf[ReceiverSupervisor]
   }
 }
 
 /**
- * A Receiver as an object so we can read its rate limit.
+ * A Receiver as an object so we can read its rate limit. Make sure to call `reset()` when
+ * reusing this receiver, otherwise a non-null `executor_` field will prevent it from being
+ * serialized when receivers are installed on executors.
  *
  * @note It's necessary to be a top-level object, or else serialization would create another
  *       one on the executor side and we won't be able to read its rate limit.
  */
-private object SingletonDummyReceiver extends DummyReceiver(0)
+private[streaming] object SingletonTestRateReceiver extends RateTestReceiver(0) {
+
+  /** Reset the object to be usable in another test. */
+  def reset(): Unit = {
+    executor_ = null
+  }
+}
 
 /**
  * Dummy receiver implementation
  */
-private class DummyReceiver(receiverId: Int, host: Option[String] = None)
+private[streaming] class RateTestReceiver(receiverId: Int, host: Option[String] = None)
   extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
 
   setReceiverId(receiverId)

From 5340dfaf94a3c54199f8cc3c78e11f61e34d0a67 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Wed, 29 Jul 2015 13:49:22 -0700
Subject: [PATCH 136/219] [SPARK-9430][SQL] Rename IntervalType to CalendarIntervalType.

We want to introduce a new IntervalType in 1.6 that is based on only the number of microseconds, so intervals can be compared.

Renaming the existing IntervalType to CalendarIntervalType so we can do that in the future.

Author: Reynold Xin

Closes #7745 from rxin/calendarintervaltype and squashes the following commits:

99f64e8 [Reynold Xin] One more line ...
13466c8 [Reynold Xin] Fixed tests.
e20f24e [Reynold Xin] [SPARK-9430][SQL] Rename IntervalType to CalendarIntervalType.
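
To make the semantics behind the rename concrete: a CalendarInterval pairs a month component with a microsecond component, and since a month has no fixed length in microseconds, two intervals cannot be ordered consistently. The sketch below is illustrative only (the demo class is not part of this patch); it uses just the constructor, arithmetic methods, and `fromString` that appear in the diff that follows:

```java
import org.apache.spark.unsafe.types.CalendarInterval;

public class CalendarIntervalDemo {
  public static void main(String[] args) {
    // One month versus thirty days: neither is consistently larger, which is
    // why the new scaladoc notes that calendar intervals are not comparable.
    CalendarInterval oneMonth = new CalendarInterval(1, 0);
    CalendarInterval thirtyDays =
        new CalendarInterval(0, 30 * CalendarInterval.MICROS_PER_DAY);

    // Component-wise arithmetic remains well defined.
    CalendarInterval sum = oneMonth.add(thirtyDays);
    System.out.println(sum);               // e.g. "interval 1 months 4 weeks 2 days"
    System.out.println(oneMonth.negate()); // e.g. "interval -1 months"

    // Parsing follows the same unit grammar exercised in IntervalSuite below.
    System.out.println(CalendarInterval.fromString("interval 3 weeks 13 hours"));
  }
}
```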
--- .../expressions/SpecializedGetters.java | 4 +- .../sql/catalyst/expressions/UnsafeRow.java | 10 +- .../expressions/UnsafeRowWriters.java | 4 +- .../org/apache/spark/sql/types/DataTypes.java | 4 +- .../spark/sql/catalyst/InternalRow.scala | 5 +- .../apache/spark/sql/catalyst/SqlParser.scala | 16 +- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 12 +- .../sql/catalyst/expressions/arithmetic.scala | 20 +-- .../expressions/codegen/CodeGenerator.scala | 6 +- .../codegen/GenerateUnsafeProjection.scala | 10 +- .../sql/catalyst/expressions/literals.scala | 2 +- .../spark/sql/types/AbstractDataType.scala | 2 +- ...lType.scala => CalendarIntervalType.scala} | 15 +- .../ExpressionTypeCheckingSuite.scala | 7 +- .../analysis/HiveTypeCoercionSuite.scala | 2 +- .../sql/catalyst/expressions/CastSuite.scala | 9 +- .../spark/sql/execution/basicOperators.scala | 131 --------------- .../spark/sql/execution/datasources/ddl.scala | 2 +- .../org/apache/spark/sql/execution/sort.scala | 159 ++++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 20 +-- .../{Interval.java => CalendarInterval.java} | 24 +-- .../spark/unsafe/types/IntervalSuite.java | 72 ++++---- 23 files changed, 286 insertions(+), 252 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/types/{IntervalType.scala => CalendarIntervalType.scala} (64%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala rename unsafe/src/main/java/org/apache/spark/unsafe/types/{Interval.java => CalendarInterval.java} (87%) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index 5f28d52a94bd7..bc345dcd00e49 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -19,7 +19,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; public interface SpecializedGetters { @@ -46,7 +46,7 @@ public interface SpecializedGetters { byte[] getBinary(int ordinal); - Interval getInterval(int ordinal); + CalendarInterval getInterval(int ordinal); InternalRow getStruct(int ordinal, int numFields); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 64a8edc34d681..6d684bac37573 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -29,7 +29,7 @@ import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; -import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; @@ -92,7 +92,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { Arrays.asList(new DataType[]{ StringType, BinaryType, - IntervalType + CalendarIntervalType })); 
_readableFieldTypes.addAll(settableFieldTypes); readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); @@ -265,7 +265,7 @@ public Object get(int ordinal, DataType dataType) { return getBinary(ordinal); } else if (dataType instanceof StringType) { return getUTF8String(ordinal); - } else if (dataType instanceof IntervalType) { + } else if (dataType instanceof CalendarIntervalType) { return getInterval(ordinal); } else if (dataType instanceof StructType) { return getStruct(ordinal, ((StructType) dataType).size()); @@ -350,7 +350,7 @@ public byte[] getBinary(int ordinal) { } @Override - public Interval getInterval(int ordinal) { + public CalendarInterval getInterval(int ordinal) { if (isNullAt(ordinal)) { return null; } else { @@ -359,7 +359,7 @@ public Interval getInterval(int ordinal) { final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); final long microseconds = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index 32faad374015c..c3259e21c4a78 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -21,7 +21,7 @@ import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; -import org.apache.spark.unsafe.types.Interval; +import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** @@ -131,7 +131,7 @@ public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow i /** Writer for interval type. */ public static class IntervalWriter { - public static int write(UnsafeRow target, int ordinal, int cursor, Interval input) { + public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) { final long offset = target.getBaseOffset() + cursor; // Write the months and microseconds fields of Interval to the variable length portion. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 5703de42393de..17659d7d960b0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -50,9 +50,9 @@ public class DataTypes { public static final DataType TimestampType = TimestampType$.MODULE$; /** - * Gets the IntervalType object. + * Gets the CalendarIntervalType object. */ - public static final DataType IntervalType = IntervalType$.MODULE$; + public static final DataType CalendarIntervalType = CalendarIntervalType$.MODULE$; /** * Gets the DoubleType object. 
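
For Java callers, the DataTypes change above is the visible entry point of the rename: the singleton is now reached as `DataTypes.CalendarIntervalType`. A minimal sketch of caller-side usage, assuming the standard Java DataTypes factory methods (the example class and field names are hypothetical, not part of this patch):

```java
import java.util.Arrays;

import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class IntervalSchemaExample {
  public static void main(String[] args) {
    // Declare a schema with an interval column; before this patch the second
    // field would have been declared with DataTypes.IntervalType.
    StructType schema = DataTypes.createStructType(Arrays.asList(
        DataTypes.createStructField("id", DataTypes.LongType, false),
        DataTypes.createStructField("span", DataTypes.CalendarIntervalType, true)));
    System.out.println(schema.simpleString());
  }
}
```

Note that, as the ddl.scala hunk later in this patch shows, such a schema still cannot be persisted to external storage.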
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index e395a67434fa7..a5999e64ec554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{Interval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as @@ -61,7 +61,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { override def getDecimal(ordinal: Int): Decimal = getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) - override def getInterval(ordinal: Int): Interval = getAs[Interval](ordinal, IntervalType) + override def getInterval(ordinal: Int): CalendarInterval = + getAs[CalendarInterval](ordinal, CalendarIntervalType) // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index e5f115f74bf3b..f2498861c9573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.Interval +import org.apache.spark.unsafe.types.CalendarInterval /** * A very simple SQL parser. 
Based loosely on: @@ -365,32 +365,32 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val millisecond: Parser[Long] = integral <~ intervalUnit("millisecond") ^^ { - case num => num.toLong * Interval.MICROS_PER_MILLI + case num => num.toLong * CalendarInterval.MICROS_PER_MILLI } protected lazy val second: Parser[Long] = integral <~ intervalUnit("second") ^^ { - case num => num.toLong * Interval.MICROS_PER_SECOND + case num => num.toLong * CalendarInterval.MICROS_PER_SECOND } protected lazy val minute: Parser[Long] = integral <~ intervalUnit("minute") ^^ { - case num => num.toLong * Interval.MICROS_PER_MINUTE + case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE } protected lazy val hour: Parser[Long] = integral <~ intervalUnit("hour") ^^ { - case num => num.toLong * Interval.MICROS_PER_HOUR + case num => num.toLong * CalendarInterval.MICROS_PER_HOUR } protected lazy val day: Parser[Long] = integral <~ intervalUnit("day") ^^ { - case num => num.toLong * Interval.MICROS_PER_DAY + case num => num.toLong * CalendarInterval.MICROS_PER_DAY } protected lazy val week: Parser[Long] = integral <~ intervalUnit("week") ^^ { - case num => num.toLong * Interval.MICROS_PER_WEEK + case num => num.toLong * CalendarInterval.MICROS_PER_WEEK } protected lazy val intervalLiteral: Parser[Literal] = @@ -406,7 +406,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { val months = Seq(year, month).map(_.getOrElse(0)).sum val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond) .map(_.getOrElse(0L)).sum - Literal.create(new Interval(months, microseconds), IntervalType) + Literal.create(new CalendarInterval(months, microseconds), CalendarIntervalType) } private def toNarrowestIntegerType(value: String): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 8304d4ccd47f7..371681b5d494f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -48,7 +48,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case DoubleType => input.getDouble(ordinal) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) - case IntervalType => input.getInterval(ordinal) + case CalendarIntervalType => input.getInterval(ordinal) case t: StructType => input.getStruct(ordinal, t.size) case _ => input.get(ordinal, dataType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bd8b0177eb00e..c6e8af27667ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{Interval, UTF8String} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import scala.collection.mutable @@ -55,7 +55,7 @@ object Cast { case (_, DateType) => true - case (StringType, IntervalType) 
=> true + case (StringType, CalendarIntervalType) => true case (StringType, _: NumericType) => true case (BooleanType, _: NumericType) => true @@ -225,7 +225,7 @@ case class Cast(child: Expression, dataType: DataType) // IntervalConverter private[this] def castToInterval(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => Interval.fromString(s.toString)) + buildCast[UTF8String](_, s => CalendarInterval.fromString(s.toString)) case _ => _ => null } @@ -398,7 +398,7 @@ case class Cast(child: Expression, dataType: DataType) case DateType => castToDate(from) case decimal: DecimalType => castToDecimal(from, decimal) case TimestampType => castToTimestamp(from) - case IntervalType => castToInterval(from) + case CalendarIntervalType => castToInterval(from) case BooleanType => castToBoolean(from) case ByteType => castToByte(from) case ShortType => castToShort(from) @@ -438,7 +438,7 @@ case class Cast(child: Expression, dataType: DataType) case DateType => castToDateCode(from, ctx) case decimal: DecimalType => castToDecimalCode(from, decimal) case TimestampType => castToTimestampCode(from, ctx) - case IntervalType => castToIntervalCode(from) + case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) case ByteType => castToByteCode(from) case ShortType => castToShortCode(from) @@ -630,7 +630,7 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castToIntervalCode(from: DataType): CastFunction = from match { case StringType => (c, evPrim, evNull) => - s"$evPrim = Interval.fromString($c.toString());" + s"$evPrim = CalendarInterval.fromString($c.toString());" } private[this] def decimalToTimestampCode(d: String): String = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 4ec866475f8b0..6f8f4dd230f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.Interval +import org.apache.spark.unsafe.types.CalendarInterval case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -37,12 +37,12 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") - case dt: IntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") + case dt: CalendarIntervalType => defineCodeGen(ctx, ev, c => s"$c.negate()") } protected override def nullSafeEval(input: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input.asInstanceOf[Interval].negate() + if (dataType.isInstanceOf[CalendarIntervalType]) { + input.asInstanceOf[CalendarInterval].negate() } else { numeric.negate(input) } @@ -121,8 +121,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: 
Any, input2: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval]) + if (dataType.isInstanceOf[CalendarIntervalType]) { + input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval]) } else { numeric.plus(input1, input2) } @@ -134,7 +134,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => + case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.add($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") @@ -150,8 +150,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti private lazy val numeric = TypeUtils.getNumeric(dataType) protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (dataType.isInstanceOf[IntervalType]) { - input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval]) + if (dataType.isInstanceOf[CalendarIntervalType]) { + input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) } else { numeric.minus(input1, input2) } @@ -163,7 +163,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti case ByteType | ShortType => defineCodeGen(ctx, ev, (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") - case IntervalType => + case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.subtract($eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2f02c90b1d5b3..092f4c9fb0bd2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -108,7 +108,7 @@ class CodeGenContext { case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" case StringType => s"$row.getUTF8String($ordinal)" case BinaryType => s"$row.getBinary($ordinal)" - case IntervalType => s"$row.getInterval($ordinal)" + case CalendarIntervalType => s"$row.getInterval($ordinal)" case t: StructType => s"$row.getStruct($ordinal, ${t.size})" case _ => s"($jt)$row.get($ordinal)" } @@ -150,7 +150,7 @@ class CodeGenContext { case dt: DecimalType => "Decimal" case BinaryType => "byte[]" case StringType => "UTF8String" - case IntervalType => "Interval" + case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" case _: ArrayType => s"scala.collection.Seq" case _: MapType => s"scala.collection.Map" @@ -293,7 +293,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UnsafeRow].getName, classOf[UTF8String].getName, classOf[Decimal].getName, - classOf[Interval].getName + classOf[CalendarInterval].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 9a4c00e86a3ec..dc725c28aaa27 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -39,7 +39,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case t: AtomicType if !t.isInstanceOf[DecimalType] => true - case _: IntervalType => true + case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true case _ => false @@ -75,7 +75,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" case BinaryType => s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" - case IntervalType => + case CalendarIntervalType => s" + (${exprs(i).isNull} ? 0 : 16)" case _: StructType => s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))" @@ -91,7 +91,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" case BinaryType => s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case IntervalType => + case CalendarIntervalType => s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" case t: StructType => s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" @@ -173,7 +173,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" case BinaryType => s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" - case IntervalType => + case CalendarIntervalType => s" + (${ev.isNull} ? 0 : 16)" case _: StructType => s" + (${ev.isNull} ? 
0 : $StructWriter.getSize(${ev.primitive}))"
@@ -189,7 +189,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
         case BinaryType =>
           s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
-        case IntervalType =>
+        case CalendarIntervalType =>
           s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
         case t: StructType =>
           s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 064a1720c36e8..34bad23802ba4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -42,7 +42,7 @@ object Literal {
     case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
     case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
     case a: Array[Byte] => Literal(a, BinaryType)
-    case i: Interval => Literal(i, IntervalType)
+    case i: CalendarInterval => Literal(i, CalendarIntervalType)
     case null => Literal(null, NullType)
     case _ =>
       throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 40bf4b299c990..e0667c629486d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -95,7 +95,7 @@ private[sql] object TypeCollection {
    * Types that include numeric types and interval type. They are only used in unary_minus,
    * unary_positive, add and subtract operations.
    */
-  val NumericAndInterval = TypeCollection(NumericType, IntervalType)
+  val NumericAndInterval = TypeCollection(NumericType, CalendarIntervalType)
 
   def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
similarity index 64%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
index 87c6e9e6e5e2c..3565f52c21f69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntervalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CalendarIntervalType.scala
@@ -22,16 +22,19 @@ import org.apache.spark.annotation.DeveloperApi
 
 /**
  * :: DeveloperApi ::
- * The data type representing time intervals.
+ * The data type representing calendar time intervals. The calendar time interval is stored
+ * internally in two components: the number of months and the number of microseconds.
  *
- * Please use the singleton [[DataTypes.IntervalType]].
+ * Note that calendar intervals are not comparable.
+ *
+ * Please use the singleton [[DataTypes.CalendarIntervalType]].
*/ @DeveloperApi -class IntervalType private() extends DataType { +class CalendarIntervalType private() extends DataType { - override def defaultSize: Int = 4096 + override def defaultSize: Int = 16 - private[spark] override def asNullable: IntervalType = this + private[spark] override def asNullable: CalendarIntervalType = this } -case object IntervalType extends IntervalType +case object CalendarIntervalType extends CalendarIntervalType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ad15136ee9a2f..8acd4c685e2bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for unary arithmetic") { - assertError(UnaryMinus('stringField), "type (numeric or interval)") + assertError(UnaryMinus('stringField), "type (numeric or calendarinterval)") assertError(Abs('stringField), "expected to be of type numeric") assertError(BitwiseNot('stringField), "expected to be of type integral") } @@ -78,8 +78,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(MaxOf('intField, 'booleanField)) assertErrorForDifferingTypes(MinOf('intField, 'booleanField)) - assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type") - assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type") + assertError(Add('booleanField, 'booleanField), "accepts (numeric or calendarinterval) type") + assertError(Subtract('booleanField, 'booleanField), + "accepts (numeric or calendarinterval) type") assertError(Multiply('booleanField, 'booleanField), "accepts numeric type") assertError(Divide('booleanField, 'booleanField), "accepts numeric type") assertError(Remainder('booleanField, 'booleanField), "accepts numeric type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 4454d51b75877..1d9ee5ddf3a5a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -116,7 +116,7 @@ class HiveTypeCoercionSuite extends PlanTest { shouldNotCast(IntegerType, MapType) shouldNotCast(IntegerType, StructType) - shouldNotCast(IntervalType, StringType) + shouldNotCast(CalendarIntervalType, StringType) // Don't implicitly cast complex types to string. 
shouldNotCast(ArrayType(StringType), StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 408353cf70a49..0e0213be0f57b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -719,12 +719,13 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("case between string and interval") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval - checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType), - new Interval(-3, 7 * Interval.MICROS_PER_HOUR)) + checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType), + new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR)) checkEvaluation(Cast(Literal.create( - new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType), + new CalendarInterval(15, -3 * CalendarInterval.MICROS_PER_DAY), CalendarIntervalType), + StringType), "interval 1 years 3 months -3 days") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b02e60dc85cdd..2294a670c735f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -220,137 +220,6 @@ case class TakeOrderedAndProject( override def outputOrdering: Seq[SortOrder] = sortOrder } -/** - * :: DeveloperApi :: - * Performs a sort on-heap. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -@DeveloperApi -case class Sort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - iterator.map(_.copy()).toArray.sorted(ordering).iterator - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * :: DeveloperApi :: - * Performs a sort, spilling to disk as needed. - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - */ -@DeveloperApi -case class ExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan) - extends UnaryNode { - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - child.execute().mapPartitions( { iterator => - val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) - sorter.insertAll(iterator.map(r => (r.copy, null))) - val baseIterator = sorter.iterator.map(_._1) - // TODO(marmbrus): The complex type signature below thwarts inference for no reason. 
- CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) - }, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder -} - -/** - * :: DeveloperApi :: - * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of - * Project Tungsten). - * - * @param global when true performs a global sort of all partitions by shuffling the data first - * if necessary. - * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will - * spill every `frequency` records. - */ -@DeveloperApi -case class UnsafeExternalSort( - sortOrder: Seq[SortOrder], - global: Boolean, - child: SparkPlan, - testSpillFrequency: Int = 0) - extends UnaryNode { - - private[this] val schema: StructType = child.schema - - override def requiredChildDistribution: Seq[Distribution] = - if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) - } - } - val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) - if (testSpillFrequency > 0) { - sorter.setTestSpillFrequency(testSpillFrequency) - } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) - } - - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true -} - -@DeveloperApi -object UnsafeExternalSort { - /** - * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
- */ - def supportsSchema(schema: StructType): Boolean = { - UnsafeExternalRowSorter.supportsSchema(schema) - } -} - /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e73b3704d4dfe..0cdb407ad57b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -308,7 +308,7 @@ private[sql] object ResolvedDataSource { mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { - if (data.schema.map(_.dataType).exists(_.isInstanceOf[IntervalType])) { + if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } val clazz: Class[_] = lookupDataSource(provider) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala new file mode 100644 index 0000000000000..f82208868c3e3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines various sort operators. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +/** + * Performs a sort on-heap. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. 
+ */ +case class Sort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + iterator.map(_.copy()).toArray.sorted(ordering).iterator + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +/** + * Performs a sort, spilling to disk as needed. + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + */ +case class ExternalSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan) + extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + child.execute().mapPartitions( { iterator => + val ordering = newOrdering(sortOrder, child.output) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) + sorter.insertAll(iterator.map(r => (r.copy(), null))) + val baseIterator = sorter.iterator.map(_._1) + // TODO(marmbrus): The complex type signature below thwarts inference for no reason. + CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop()) + }, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder +} + +/** + * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of + * Project Tungsten). + * + * @param global when true performs a global sort of all partitions by shuffling the data first + * if necessary. + * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will + * spill every `frequency` records. + */ +case class UnsafeExternalSort( + sortOrder: Seq[SortOrder], + global: Boolean, + child: SparkPlan, + testSpillFrequency: Int = 0) + extends UnaryNode { + + private[this] val schema: StructType = child.schema + + override def requiredChildDistribution: Seq[Distribution] = + if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") + def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { + val ordering = newOrdering(sortOrder, child.output) + val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) + // Hack until we generate separate comparator implementations for ascending vs. 
descending + // (or choose to codegen them): + val prefixComparator = { + val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) + if (sortOrder.head.direction == Descending) { + new PrefixComparator { + override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) + } + } else { + comp + } + } + val prefixComputer = { + val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = prefixComputer(row) + } + } + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + if (testSpillFrequency > 0) { + sorter.setTestSpillFrequency(testSpillFrequency) + } + sorter.sort(iterator) + } + child.execute().mapPartitions(doSort, preservesPartitioning = true) + } + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder + + override def outputsUnsafeRows: Boolean = true +} + +@DeveloperApi +object UnsafeExternalSort { + /** + * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. + */ + def supportsSchema(schema: StructType): Boolean = { + UnsafeExternalRowSorter.supportsSchema(schema) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d13dde1cdc8b2..535011fe3db5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1577,10 +1577,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-8753: add interval type") { - import org.apache.spark.unsafe.types.Interval + import org.apache.spark.unsafe.types.CalendarInterval val df = sql("select interval 3 years -3 month 7 week 123 microseconds") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 ))) withTempPath(f => { // Currently we don't yet support saving out values of interval data type. 
val e = intercept[AnalysisException] { @@ -1602,20 +1602,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-8945: add and subtract expressions for interval type") { - import org.apache.spark.unsafe.types.Interval - import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK + import org.apache.spark.unsafe.types.CalendarInterval + import org.apache.spark.unsafe.types.CalendarInterval.MICROS_PER_WEEK val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i") - checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123))) - checkAnswer(df.select(df("i") + new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) + checkAnswer(df.select(df("i") + new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123))) - checkAnswer(df.select(df("i") - new Interval(2, 123)), - Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) + checkAnswer(df.select(df("i") - new CalendarInterval(2, 123)), + Row(new CalendarInterval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123))) // unary minus checkAnswer(df.select(-df("i")), - Row(new Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) + Row(new CalendarInterval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123)))) } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java similarity index 87% rename from unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java rename to unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 71b1a85a818ea..92a5e4f86f234 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -24,7 +24,7 @@ /** * The internal representation of interval type. 
*/ -public final class Interval implements Serializable { +public final class CalendarInterval implements Serializable { public static final long MICROS_PER_MILLI = 1000L; public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; @@ -58,7 +58,7 @@ private static long toLong(String s) { } } - public static Interval fromString(String s) { + public static CalendarInterval fromString(String s) { if (s == null) { return null; } @@ -75,40 +75,40 @@ public static Interval fromString(String s) { microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; microseconds += toLong(m.group(9)); - return new Interval((int) months, microseconds); + return new CalendarInterval((int) months, microseconds); } } public final int months; public final long microseconds; - public Interval(int months, long microseconds) { + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; } - public Interval add(Interval that) { + public CalendarInterval add(CalendarInterval that) { int months = this.months + that.months; long microseconds = this.microseconds + that.microseconds; - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } - public Interval subtract(Interval that) { + public CalendarInterval subtract(CalendarInterval that) { int months = this.months - that.months; long microseconds = this.microseconds - that.microseconds; - return new Interval(months, microseconds); + return new CalendarInterval(months, microseconds); } - public Interval negate() { - return new Interval(-this.months, -this.microseconds); + public CalendarInterval negate() { + return new CalendarInterval(-this.months, -this.microseconds); } @Override public boolean equals(Object other) { if (this == other) return true; - if (other == null || !(other instanceof Interval)) return false; + if (other == null || !(other instanceof CalendarInterval)) return false; - Interval o = (Interval) other; + CalendarInterval o = (CalendarInterval) other; return this.months == o.months && this.microseconds == o.microseconds; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java index d29517cda66a3..e6733a7aae6f5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java @@ -20,16 +20,16 @@ import org.junit.Test; import static junit.framework.Assert.*; -import static org.apache.spark.unsafe.types.Interval.*; +import static org.apache.spark.unsafe.types.CalendarInterval.*; public class IntervalSuite { @Test public void equalsTest() { - Interval i1 = new Interval(3, 123); - Interval i2 = new Interval(3, 321); - Interval i3 = new Interval(1, 123); - Interval i4 = new Interval(3, 123); + CalendarInterval i1 = new CalendarInterval(3, 123); + CalendarInterval i2 = new CalendarInterval(3, 321); + CalendarInterval i3 = new CalendarInterval(1, 123); + CalendarInterval i4 = new CalendarInterval(3, 123); assertNotSame(i1, i2); assertNotSame(i1, i3); @@ -39,21 +39,21 @@ public void equalsTest() { @Test public void toStringTest() { - Interval i; + CalendarInterval i; - i = new Interval(34, 0); + i = new CalendarInterval(34, 0); assertEquals(i.toString(), "interval 2 years 10 months"); - i = new Interval(-34, 0); + i = new CalendarInterval(-34, 0); 
assertEquals(i.toString(), "interval -2 years -10 months"); - i = new Interval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); assertEquals(i.toString(), "interval 3 weeks 13 hours 123 microseconds"); - i = new Interval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); + i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); assertEquals(i.toString(), "interval -3 weeks -13 hours -123 microseconds"); - i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); + i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds"); } @@ -72,33 +72,33 @@ public void fromStringTest() { String input; input = "interval -5 years 23 month"; - Interval result = new Interval(-5 * 12 + 23, 0); - assertEquals(Interval.fromString(input), result); + CalendarInterval result = new CalendarInterval(-5 * 12 + 23, 0); + assertEquals(CalendarInterval.fromString(input), result); input = "interval -5 years 23 month "; - assertEquals(Interval.fromString(input), result); + assertEquals(CalendarInterval.fromString(input), result); input = " interval -5 years 23 month "; - assertEquals(Interval.fromString(input), result); + assertEquals(CalendarInterval.fromString(input), result); // Error cases input = "interval 3month 1 hour"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "interval 3 moth 1 hour"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "interval"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = "int"; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = ""; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); input = null; - assertEquals(Interval.fromString(input), null); + assertEquals(CalendarInterval.fromString(input), null); } @Test @@ -106,18 +106,18 @@ public void addTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - Interval interval = Interval.fromString(input); - Interval interval2 = Interval.fromString(input2); + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.add(interval2), new Interval(5, 101 * MICROS_PER_HOUR)); + assertEquals(interval.add(interval2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); input = "interval -10 month -81 hour"; input2 = "interval 75 month 200 hour"; - interval = Interval.fromString(input); - interval2 = Interval.fromString(input2); + interval = CalendarInterval.fromString(input); + interval2 = CalendarInterval.fromString(input2); - assertEquals(interval.add(interval2), new Interval(65, 119 * MICROS_PER_HOUR)); + assertEquals(interval.add(interval2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); } @Test @@ -125,25 +125,25 @@ public void subtractTest() { String input = "interval 3 month 1 hour"; String input2 = "interval 2 month 100 hour"; - Interval interval = Interval.fromString(input); - Interval interval2 = Interval.fromString(input2); + CalendarInterval interval = CalendarInterval.fromString(input); + CalendarInterval interval2 = 
CalendarInterval.fromString(input2);
 
-    assertEquals(interval.subtract(interval2), new Interval(1, -99 * MICROS_PER_HOUR));
+    assertEquals(interval.subtract(interval2), new CalendarInterval(1, -99 * MICROS_PER_HOUR));
 
     input = "interval -10 month -81 hour";
     input2 = "interval 75 month 200 hour";
 
-    interval = Interval.fromString(input);
-    interval2 = Interval.fromString(input2);
+    interval = CalendarInterval.fromString(input);
+    interval2 = CalendarInterval.fromString(input2);
 
-    assertEquals(interval.subtract(interval2), new Interval(-85, -281 * MICROS_PER_HOUR));
+    assertEquals(interval.subtract(interval2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR));
   }
 
   private void testSingleUnit(String unit, int number, int months, long microseconds) {
     String input1 = "interval " + number + " " + unit;
     String input2 = "interval " + number + " " + unit + "s";
-    Interval result = new Interval(months, microseconds);
-    assertEquals(Interval.fromString(input1), result);
-    assertEquals(Interval.fromString(input2), result);
+    CalendarInterval result = new CalendarInterval(months, microseconds);
+    assertEquals(CalendarInterval.fromString(input1), result);
+    assertEquals(CalendarInterval.fromString(input2), result);
   }
 }

From b715933fc69a49653abdb2fba0818dfc4f35d358 Mon Sep 17 00:00:00 2001
From: Alexander Ulanov
Date: Wed, 29 Jul 2015 13:59:00 -0700
Subject: [PATCH 137/219] [SPARK-9436] [GRAPHX] Pregel simplification patch

Pregel code contains two consecutive joins:
```
g.vertices.innerJoin(messages)(vprog)
...
g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
```
This can be simplified with one join. ankurdave proposed a patch based on our discussion on the mailing list: https://www.mail-archive.com/dev@spark.apache.org/msg10316.html

Author: Alexander Ulanov

Closes #7749 from avulanov/SPARK-9436-pregel and squashes the following commits:

8568e06 [Alexander Ulanov] Pregel simplification patch
---
 .../org/apache/spark/graphx/Pregel.scala      | 23 ++++++++-----------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index cfcf7244eaed5..2ca60d51f8331 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -127,28 +127,25 @@ object Pregel extends Logging {
     var prevG: Graph[VD, ED] = null
     var i = 0
     while (activeMessages > 0 && i < maxIterations) {
-      // Receive the messages. Vertices that didn't get any messages do not appear in newVerts.
-      val newVerts = g.vertices.innerJoin(messages)(vprog).cache()
-      // Update the graph with the new vertices.
+      // Receive the messages and update the vertices.
       prevG = g
-      g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }
-      g.cache()
+      g = g.joinVertices(messages)(vprog).cache()
 
       val oldMessages = messages
-      // Send new messages. Vertices that didn't get any messages don't appear in newVerts, so don't
-      // get to send messages. We must cache messages so it can be materialized on the next line,
-      // allowing us to uncache the previous iteration.
-      messages = g.mapReduceTriplets(sendMsg, mergeMsg, Some((newVerts, activeDirection))).cache()
-      // The call to count() materializes `messages`, `newVerts`, and the vertices of `g`. This
-      // hides oldMessages (depended on by newVerts), newVerts (depended on by messages), and the
-      // vertices of prevG (depended on by newVerts, oldMessages, and the vertices of g).
+ // Send new messages, skipping edges where neither side received a message. We must cache + // messages so it can be materialized on the next line, allowing us to uncache the previous + // iteration. + messages = g.mapReduceTriplets( + sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages + // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages + // and the vertices of g). activeMessages = messages.count() logInfo("Pregel finished iteration " + i) // Unpersist the RDDs hidden by newly-materialized RDDs oldMessages.unpersist(blocking = false) - newVerts.unpersist(blocking = false) prevG.unpersistVertices(blocking = false) prevG.edges.unpersist(blocking = false) // count the iteration From 1b0099fc62d02ff6216a76fbfe17a4ec5b2f3536 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Wed, 29 Jul 2015 16:00:30 -0700 Subject: [PATCH 138/219] [SPARK-9411] [SQL] Make Tungsten page sizes configurable We need to make page sizes configurable so we can reduce them in unit tests and increase them in real production workloads. These sizes are now controlled by a new configuration, `spark.buffer.pageSize`. The new default is 64 megabytes. Author: Josh Rosen Closes #7741 from JoshRosen/SPARK-9411 and squashes the following commits: a43c4db [Josh Rosen] Fix pow 2c0eefc [Josh Rosen] Fix MAXIMUM_PAGE_SIZE_BYTES comment + value bccfb51 [Josh Rosen] Lower page size to 4MB in TestHive ba54d4b [Josh Rosen] Make UnsafeExternalSorter's page size configurable 0045aa2 [Josh Rosen] Make UnsafeShuffle's page size configurable bc734f0 [Josh Rosen] Rename configuration e614858 [Josh Rosen] Makes BytesToBytesMap page size configurable --- .../unsafe/UnsafeShuffleExternalSorter.java | 35 +++++++++------ .../shuffle/unsafe/UnsafeShuffleWriter.java | 5 +++ .../unsafe/sort/UnsafeExternalSorter.java | 30 +++++++------ .../unsafe/UnsafeShuffleWriterSuite.java | 6 +-- .../UnsafeFixedWidthAggregationMap.java | 5 ++- .../UnsafeFixedWidthAggregationMapSuite.scala | 6 ++- .../sql/execution/GeneratedAggregate.scala | 4 +- .../sql/execution/joins/HashedRelation.scala | 7 ++- .../apache/spark/sql/hive/test/TestHive.scala | 1 + .../spark/unsafe/map/BytesToBytesMap.java | 43 ++++++++++++------- .../unsafe/memory/TaskMemoryManager.java | 13 ++++-- .../map/AbstractBytesToBytesMapSuite.java | 22 +++++----- 12 files changed, 112 insertions(+), 65 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java index 1d460432be9ff..1aa6ba4201261 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java @@ -59,14 +59,14 @@ final class UnsafeShuffleExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class); - private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES; @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; private final int initialSize; private final int numPartitions; + private final int pageSizeBytes; + @VisibleForTesting + final int maxRecordSizeBytes; private final TaskMemoryManager memoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final 
BlockManager blockManager; @@ -109,7 +109,10 @@ public UnsafeShuffleExternalSorter( this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - + this.pageSizeBytes = (int) Math.min( + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + conf.getSizeAsBytes("spark.buffer.pageSize", "64m")); + this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); } @@ -272,7 +275,11 @@ void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; } private long freeMemory() { @@ -346,23 +353,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. - if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index 764578b181422..d47d6fc9c2ac4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -129,6 +129,11 @@ public UnsafeShuffleWriter( open(); } + @VisibleForTesting + public int maxRecordSizeBytes() { + return sorter.maxRecordSizeBytes; + } + /** * This convenience method should only be called in test code. 
*/ diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 80b03d7e99e2b..c21990f4e4778 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -41,10 +41,7 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private static final int PAGE_SIZE = 1 << 27; // 128 megabytes - @VisibleForTesting - static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; - + private final long pageSizeBytes; private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; @@ -91,6 +88,7 @@ public UnsafeExternalSorter( this.initialSize = initialSize; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); initializeForWriting(); } @@ -147,7 +145,11 @@ public void spill() throws IOException { } private long getMemoryUsage() { - return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE); + long totalPageSize = 0; + for (MemoryBlock page : allocatedPages) { + totalPageSize += page.size(); + } + return sorter.getMemoryUsage() + totalPageSize; } @VisibleForTesting @@ -214,23 +216,23 @@ private void allocateSpaceForRecord(int requiredSpace) throws IOException { // TODO: we should track metrics on the amount of space wasted when we roll over to a new page // without using the free space at the end of the current page. We should also do this for // BytesToBytesMap. 
- if (requiredSpace > PAGE_SIZE) { + if (requiredSpace > pageSizeBytes) { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - PAGE_SIZE + ")"); + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquired < PAGE_SIZE) { + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquired < pageSizeBytes) { shuffleMemoryManager.release(memoryAcquired); spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE); - if (memoryAcquiredAfterSpilling != PAGE_SIZE) { + final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); + if (memoryAcquiredAfterSpilling != pageSizeBytes) { shuffleMemoryManager.release(memoryAcquiredAfterSpilling); - throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory"); + throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(PAGE_SIZE); + currentPage = memoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = PAGE_SIZE; + freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 10c3eedbf4b46..04fc09b323dbb 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -111,7 +111,7 @@ public void setUp() throws IOException { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf(); + conf = new SparkConf().set("spark.buffer.pageSize", "128m"); taskMetrics = new TaskMetrics(); when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); @@ -512,12 +512,12 @@ public void close() { } writer.insertRecordIntoSorter(new Tuple2(new byte[1], new byte[1])); writer.forceSorterToSpill(); // We should be able to write a record that's right _at_ the max record size - final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE]; + final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()]; new Random(42).nextBytes(atMaxRecordSize); writer.insertRecordIntoSorter(new Tuple2(new byte[0], atMaxRecordSize)); writer.forceSorterToSpill(); // Inserting a record that's larger than the max record size should fail: - final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1]; + final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1]; new Random(42).nextBytes(exceedsMaxRecordSize); Product2 hugeRecord = new Tuple2(new byte[0], exceedsMaxRecordSize); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 684de6e81d67c..03f4c3ed8e6bb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -95,6 +95,7 @@ public static boolean supportsAggregationBufferSchema(StructType 
schema) { * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). + * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( @@ -103,11 +104,13 @@ public UnsafeFixedWidthAggregationMap( StructType groupingKeySchema, TaskMemoryManager memoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.map = + new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; // Initialize the buffer for aggregation value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 48b7dc57451a3..6a907290f2dbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -39,6 +39,7 @@ class UnsafeFixedWidthAggregationMapSuite private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) private def emptyAggregationBuffer: InternalRow = InternalRow(0) + private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes private var memoryManager: TaskMemoryManager = null @@ -69,7 +70,8 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, memoryManager, - 1024, // initial capacity + 1024, // initial capacity, + PAGE_SIZE_BYTES, false // disable perf metrics ) assert(!map.iterator().hasNext) @@ -83,6 +85,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, memoryManager, 1024, // initial capacity + PAGE_SIZE_BYTES, false // disable perf metrics ) val groupKey = InternalRow(UTF8String.fromString("cats")) @@ -109,6 +112,7 @@ class UnsafeFixedWidthAggregationMapSuite groupKeySchema, memoryManager, 128, // initial capacity + PAGE_SIZE_BYTES, false // disable perf metrics ) val rand = new Random(42) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 1cd1420480f03..b85aada9d9d4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.TaskContext +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -260,12 +260,14 @@ case class GeneratedAggregate( } else if (unsafeEnabled && schemaSupportsUnsafe) { assert(iter.hasNext, "There should be at 
least one row for this path") log.info("Using Unsafe-based aggregator") + val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") val aggregationMap = new UnsafeFixedWidthAggregationMap( newAggregationBuffer(EmptyRow), aggregationBufferSchema, groupKeySchema, TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity + pageSizeBytes, false // disable tracking of performance metrics ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 9c058f1f72fe4..7a507391316a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,6 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -259,7 +260,11 @@ private[joins] final class UnsafeHashedRelation( val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash collision + val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + binaryMap = new BytesToBytesMap( + memoryManager, + nKeys * 2, // reduce hash collision + pageSizeBytes) var i = 0 var keyBuffer = new Array[Byte](1024) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 3662a4352f55d..7bbdef90cd6b9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -56,6 +56,7 @@ object TestHive .set("spark.sql.test", "") .set("spark.sql.hive.metastore.barrierPrefixes", "org.apache.spark.sql.hive.execution.PairSerDe") + .set("spark.buffer.pageSize", "4m") // SPARK-8910 .set("spark.ui.enabled", "false"))) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d0bde69cc1068..198e0684f32f8 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -74,12 +74,6 @@ public final class BytesToBytesMap { */ private long pageCursor = 0; - /** - * The size of the data pages that hold key and value data. Map entries cannot span multiple - * pages, so this limits the maximum entry size. - */ - private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes - /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since @@ -117,6 +111,12 @@ public final class BytesToBytesMap { private final double loadFactor; + /** + * The size of the data pages that hold key and value data. Map entries cannot span multiple + * pages, so this limits the maximum entry size. + */ + private final long pageSizeBytes; + /** * Number of keys defined in the map. 
*/ @@ -153,10 +153,12 @@ public BytesToBytesMap( TaskMemoryManager memoryManager, int initialCapacity, double loadFactor, + long pageSizeBytes, boolean enablePerfMetrics) { this.memoryManager = memoryManager; this.loadFactor = loadFactor; this.loc = new Location(); + this.pageSizeBytes = pageSizeBytes; this.enablePerfMetrics = enablePerfMetrics; if (initialCapacity <= 0) { throw new IllegalArgumentException("Initial capacity must be greater than 0"); @@ -165,18 +167,26 @@ public BytesToBytesMap( throw new IllegalArgumentException( "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY); } + if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); + } allocate(initialCapacity); } - public BytesToBytesMap(TaskMemoryManager memoryManager, int initialCapacity) { - this(memoryManager, initialCapacity, 0.70, false); + public BytesToBytesMap( + TaskMemoryManager memoryManager, + int initialCapacity, + long pageSizeBytes) { + this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( TaskMemoryManager memoryManager, int initialCapacity, + long pageSizeBytes, boolean enablePerfMetrics) { - this(memoryManager, initialCapacity, 0.70, enablePerfMetrics); + this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics); } /** @@ -443,20 +453,20 @@ public void putNewKey( // must be stored in the same memory page. // (8 byte key length) (key) (8 byte value length) (value) final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes; - assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker. + assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker. size++; bitset.set(pos); // If there's not enough space in the current page, allocate a new page (8 bytes are reserved // for the end-of-page marker). - if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) { + if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { if (currentDataPage != null) { // There wasn't enough space in the current page, so write an end-of-page marker: final Object pageBaseObject = currentDataPage.getBaseObject(); final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); } - MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES); + MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes); dataPages.add(newPage); pageCursor = 0; currentDataPage = newPage; @@ -538,10 +548,11 @@ public void free() { /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. 
*/
   public long getTotalMemoryConsumption() {
-    return (
-      dataPages.size() * PAGE_SIZE_BYTES +
-      bitset.memoryBlock().size() +
-      longArray.memoryBlock().size());
+    long totalDataPagesSize = 0L;
+    for (MemoryBlock dataPage : dataPages) {
+      totalDataPagesSize += dataPage.size();
+    }
+    return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
   }
 
   /**
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 10881969dbc78..dd70df3b1f791 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -58,8 +58,13 @@ public class TaskMemoryManager {
   /** The number of entries in the page table. */
   private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
 
-  /** Maximum supported data page size */
-  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
+  /**
+   * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
+   * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page
+   * size is limited by the maximum amount of data that can be stored in a long[] array, which is
+   * (2^31 - 1) * 8 bytes (or about 16 gigabytes). Therefore, we cap this at 16 gigabytes.
+   */
+  public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
 
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -110,9 +115,9 @@ public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
    * intended for allocating large blocks of memory that will be shared between operators.
*/ public MemoryBlock allocatePage(long size) { - if (size > MAXIMUM_PAGE_SIZE) { + if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( - "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes"); + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } final int pageNumber; diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index dae47e4bab0cb..0be94ad371255 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -43,6 +43,7 @@ public abstract class AbstractBytesToBytesMapSuite { private TaskMemoryManager memoryManager; private TaskMemoryManager sizeLimitedMemoryManager; + private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes @Before public void setup() { @@ -110,7 +111,7 @@ private static boolean arrayEquals( @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); try { Assert.assertEquals(0, map.size()); final int keyLengthInWords = 10; @@ -125,7 +126,7 @@ public void emptyMap() { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); @@ -177,7 +178,7 @@ public void setAndRetrieveAKey() { @Test public void iteratorTest() throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2); + BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -235,7 +236,7 @@ public void iteratingOverDataPagesWithWastedSpace() throws Exception { final int NUM_ENTRIES = 1000 * 1000; final int KEY_LENGTH = 16; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. This is necessary in order for us to take the end-of-record @@ -304,7 +305,7 @@ public void randomizedStressTest() { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. 
final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size); + final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing @@ -353,14 +354,15 @@ public void randomizedStressTest() { @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedMemoryManager, 0); + new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } try { - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1); + new BytesToBytesMap( + sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -368,15 +370,15 @@ public void initialCapacityBoundsChecking() { // Can allocate _at_ the max capacity BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY); + new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES); map.free(); } @Test public void resizingLargeMap() { // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = - new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64); + BytesToBytesMap map = new BytesToBytesMap( + sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES); map.growAndRehash(); map.free(); } From 2cc212d56a1d50fe68d5816f71b27803de1f6389 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 29 Jul 2015 16:20:20 -0700 Subject: [PATCH 139/219] [SPARK-6793] [MLLIB] OnlineLDAOptimizer LDA perplexity Implements `logPerplexity` in `OnlineLDAOptimizer`. Also refactors inference code into companion object to enable future reuse (e.g. `predict` method). 
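For a sense of how the new method is used, here is a minimal usage sketch (illustrative only; `trainingCorpus` and `testCorpus` are assumed RDD[(Long, Vector)] corpora of document id to term-count vector, and are not part of this patch):
```
// Train with the online optimizer; its getLDAModel returns a LocalLDAModel.
val lda = new LDA().setK(10).setOptimizer(new OnlineLDAOptimizer())
val model = lda.run(trainingCorpus).asInstanceOf[LocalLDAModel]
// New in this patch: per-word log variational bound on held-out documents.
val perWordBound = model.logPerplexity(testCorpus)
```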
Author: Feynman Liang Closes #7705 from feynmanliang/SPARK-6793-perplexity and squashes the following commits: 6da2c99 [Feynman Liang] Remove get* from LDAModel public API 8381da6 [Feynman Liang] Code review comments 17f7000 [Feynman Liang] Documentation typo fixes 2f452a4 [Feynman Liang] Remove auxillary DistributedLDAModel constructor a275914 [Feynman Liang] Prevent empty counts calls to variationalInference 06d02d9 [Feynman Liang] Remove deprecated LocalLDAModel constructor afecb46 [Feynman Liang] Fix regression bug in sstats accumulator 5a327a0 [Feynman Liang] Code review quick fixes 998c03e [Feynman Liang] Fix style 1cbb67d [Feynman Liang] Fix access modifier bug 4362daa [Feynman Liang] Organize imports 4f171f7 [Feynman Liang] Fix indendation 2f049ce [Feynman Liang] Fix failing save/load tests 7415e96 [Feynman Liang] Pick changes from big PR 11e7c33 [Feynman Liang] Merge remote-tracking branch 'apache/master' into SPARK-6793-perplexity f8adc48 [Feynman Liang] Add logPerplexity, refactor variationalBound into a method cd521d6 [Feynman Liang] Refactor methods into companion class 7f62a55 [Feynman Liang] --amend c62cb1e [Feynman Liang] Outer product for stats, revert Range slicing aead650 [Feynman Liang] Range slice, in-place update, reduce transposes --- .../spark/mllib/clustering/LDAModel.scala | 200 ++++++++++++++---- .../spark/mllib/clustering/LDAOptimizer.scala | 138 +++++++----- .../spark/mllib/clustering/LDAUtils.scala | 55 +++++ .../spark/mllib/clustering/JavaLDASuite.java | 6 +- .../spark/mllib/clustering/LDASuite.scala | 53 ++++- 5 files changed, 348 insertions(+), 104 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 31c1d520fd659..059b52ef20a98 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV} - +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -28,14 +27,13 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector} -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.graphx.{Edge, EdgeContext, Graph, VertexId} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.BoundedPriorityQueue - /** * :: Experimental :: * @@ -53,6 +51,31 @@ abstract class LDAModel private[clustering] extends Saveable { /** Vocabulary size (number of terms or terms in the vocabulary) */ def vocabSize: Int + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over 
topics ("theta"). + * + * This is the parameter to a Dirichlet distribution. + */ + def docConcentration: Vector + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + def topicConcentration: Double + + /** + * Shape parameter for random initialization of variational parameter gamma. + * Used for variational inference for perplexity and other test-time computations. + */ + protected def gammaShape: Double + /** * Inferred topics, where each topic is represented by a distribution over terms. * This is a matrix of size vocabSize x k, where each column is a topic. @@ -168,7 +191,10 @@ abstract class LDAModel private[clustering] extends Saveable { */ @Experimental class LocalLDAModel private[clustering] ( - private val topics: Matrix) extends LDAModel with Serializable { + val topics: Matrix, + override val docConcentration: Vector, + override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double) extends LDAModel with Serializable { override def k: Int = topics.numCols @@ -197,8 +223,82 @@ class LocalLDAModel private[clustering] ( // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + /** + * Calculate the log variational bound on perplexity. See Equation (16) in original Online + * LDA paper. + * @param documents test corpus to use for calculating perplexity + * @return the log perplexity per word + */ + def logPerplexity(documents: RDD[(Long, Vector)]): Double = { + val corpusWords = documents + .map { case (_, termCounts) => termCounts.toArray.sum } + .sum() + val batchVariationalBound = bound(documents, docConcentration, + topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize) + val perWordBound = batchVariationalBound / corpusWords + + perWordBound + } + + /** + * Estimate the variational likelihood bound of from `documents`: + * log p(documents) >= E_q[log p(documents)] - E_q[log q(documents)] + * This bound is derived by decomposing the LDA model to: + * log p(documents) = E_q[log p(documents)] - E_q[log q(documents)] + D(q|p) + * and noting that the KL-divergence D(q|p) >= 0. See Equation (16) in original Online LDA paper. 
+ * @param documents a subset of the test corpus + * @param alpha document-topic Dirichlet prior parameters + * @param eta topic-word Dirichlet prior parameters + * @param lambda parameters for variational q(beta | lambda) topic-word distributions + * @param gammaShape shape parameter for random initialization of variational q(theta | gamma) + * topic mixture distributions + * @param k number of topics + * @param vocabSize number of unique terms in the entire test corpus + */ + private def bound( + documents: RDD[(Long, Vector)], + alpha: Vector, + eta: Double, + lambda: BDM[Double], + gammaShape: Double, + k: Int, + vocabSize: Long): Double = { + val brzAlpha = alpha.toBreeze.toDenseVector + // transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t + + var score = documents.filter(_._2.numActives > 0).map { case (id: Long, termCounts: Vector) => + var docScore = 0.0D + val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, exp(Elogbeta), brzAlpha, gammaShape, k) + val Elogthetad: BDV[Double] = LDAUtils.dirichletExpectation(gammad) + + // E[log p(doc | theta, beta)] + termCounts.foreachActive { case (idx, count) => + docScore += LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) + } + // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector + docScore += sum((brzAlpha - gammad) :* Elogthetad) + docScore += sum(lgamma(gammad) - lgamma(brzAlpha)) + docScore += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) + + docScore + }.sum() + + // E[log p(beta | eta) - log q (beta | lambda)]; assumes eta is a scalar + score += sum((eta - lambda) :* Elogbeta) + score += sum(lgamma(lambda) - lgamma(eta)) + + val sumEta = eta * vocabSize + score += sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) + + score + } + } + @Experimental object LocalLDAModel extends Loader[LocalLDAModel] { @@ -212,6 +312,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] { // as a Row in data. 
case class Data(topic: Vector, index: Int) + // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in + // model.predict() def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ @@ -219,7 +321,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix @@ -243,7 +345,11 @@ object LocalLDAModel extends Loader[LocalLDAModel] { topics.foreach { case Row(vec: Vector, ind: Int) => brzTopics(::, ind) := vec.toBreeze } - new LocalLDAModel(Matrices.fromBreeze(brzTopics)) + val topicsMat = Matrices.fromBreeze(brzTopics) + + // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 + new LocalLDAModel(topicsMat, + Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D) } } @@ -259,8 +365,8 @@ object LocalLDAModel extends Loader[LocalLDAModel] { SaveLoadV1_0.load(sc, path) case _ => throw new Exception( s"LocalLDAModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $loadedVersion). Supported:\n" + - s" ($classNameV1_0, 1.0)") + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") } val topicsMatrix = model.topicsMatrix @@ -268,7 +374,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics") require(expectedVocabSize == topicsMatrix.numRows, s"LocalLDAModel requires $expectedVocabSize terms for each topic, " + - s"but got ${topicsMatrix.numRows}") + s"but got ${topicsMatrix.numRows}") model } } @@ -282,28 +388,25 @@ object LocalLDAModel extends Loader[LocalLDAModel] { * than the [[LocalLDAModel]]. */ @Experimental -class DistributedLDAModel private ( +class DistributedLDAModel private[clustering] ( private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount], private[clustering] val globalTopicTotals: LDA.TopicCounts, val k: Int, val vocabSize: Int, - private[clustering] val docConcentration: Double, - private[clustering] val topicConcentration: Double, + override val docConcentration: Vector, + override val topicConcentration: Double, + override protected[clustering] val gammaShape: Double, private[spark] val iterationTimes: Array[Double]) extends LDAModel { import LDA._ - private[clustering] def this(state: EMLDAOptimizer, iterationTimes: Array[Double]) = { - this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, - state.topicConcentration, iterationTimes) - } - /** * Convert model to a local model. * The local model stores the inferred topics but not the topic distributions for training * documents. */ - def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix, docConcentration, topicConcentration, + gammaShape) /** * Inferred topics, where each topic is represented by a distribution over terms. @@ -375,8 +478,9 @@ class DistributedLDAModel private ( * hyperparameters. 
*/ lazy val logLikelihood: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration assert(eta > 1.0) assert(alpha > 1.0) val N_k = globalTopicTotals @@ -400,8 +504,9 @@ class DistributedLDAModel private ( * log P(topics, topic distributions for docs | alpha, eta) */ lazy val logPrior: Double = { - val eta = topicConcentration - val alpha = docConcentration + // TODO: generalize this for asymmetric (non-scalar) alpha + val alpha = this.docConcentration(0) // To avoid closure capture of enclosing object + val eta = this.topicConcentration // Term vertices: Compute phi_{wk}. Use to compute prior log probability. // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. val N_k = globalTopicTotals @@ -412,12 +517,12 @@ class DistributedLDAModel private ( val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * brzSum(phi_wk.map(math.log)) + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + (alpha - 1.0) * sum(theta_kj.map(math.log)) } } graph.vertices.aggregate(0.0)(seqOp, _ + _) @@ -448,7 +553,7 @@ class DistributedLDAModel private ( override def save(sc: SparkContext, path: String): Unit = { DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, - iterationTimes) + iterationTimes, gammaShape) } } @@ -478,17 +583,20 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { globalTopicTotals: LDA.TopicCounts, k: Int, vocabSize: Int, - docConcentration: Double, + docConcentration: Vector, topicConcentration: Double, - iterationTimes: Array[Double]): Unit = { + iterationTimes: Array[Double], + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val metadata = compact(render (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~ - ("topicConcentration" -> topicConcentration) ~ - ("iterationTimes" -> iterationTimes.toSeq))) + ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("iterationTimes" -> iterationTimes.toSeq) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString @@ -510,9 +618,10 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { sc: SparkContext, path: String, vocabSize: Int, - docConcentration: Double, + docConcentration: Vector, topicConcentration: Double, - iterationTimes: Array[Double]): DistributedLDAModel = { + iterationTimes: Array[Double], + gammaShape: Double): DistributedLDAModel = { val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString @@ -536,7 +645,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val graph: 
Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, - docConcentration, topicConcentration, iterationTimes) + docConcentration, topicConcentration, gammaShape, iterationTimes) } } @@ -546,32 +655,35 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val vocabSize = (metadata \ "vocabSize").extract[Int] - val docConcentration = (metadata \ "docConcentration").extract[Double] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) val topicConcentration = (metadata \ "topicConcentration").extract[Double] val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] + val gammaShape = (metadata \ "gammaShape").extract[Double] val classNameV1_0 = SaveLoadV1_0.classNameV1_0 val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => { - DistributedLDAModel.SaveLoadV1_0.load( - sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray) + DistributedLDAModel.SaveLoadV1_0.load(sc, path, vocabSize, docConcentration, + topicConcentration, iterationTimes.toArray, gammaShape) } case _ => throw new Exception( s"DistributedLDAModel.load did not recognize model with (className, format version):" + - s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") } require(model.vocabSize == vocabSize, s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize") require(model.docConcentration == docConcentration, s"DistributedLDAModel requires $docConcentration docConcentration, " + - s"got ${model.docConcentration} docConcentration") + s"got ${model.docConcentration} docConcentration") require(model.topicConcentration == topicConcentration, s"DistributedLDAModel requires $topicConcentration docConcentration, " + - s"got ${model.topicConcentration} docConcentration") + s"got ${model.topicConcentration} docConcentration") require(expectedK == model.k, s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics") model } } + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index f4170a3d98dd8..7e75e7083acb5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import java.util.Random import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} -import breeze.numerics.{abs, digamma, exp} +import breeze.numerics.{abs, exp} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.DeveloperApi @@ -208,7 +208,11 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") this.graphCheckpointer.deleteAllCheckpoints() - new DistributedLDAModel(this, iterationTimes) + // This assumes gammaShape = 100 in OnlineLDAOptimizer to ensure equivalence in LDAModel.toLocal + // conversion + new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, + 
Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, + 100, iterationTimes) } } @@ -385,71 +389,52 @@ final class OnlineLDAOptimizer extends LDAOptimizer { iteration += 1 val k = this.k val vocabSize = this.vocabSize - val Elogbeta = dirichletExpectation(lambda).t - val expElogbeta = exp(Elogbeta) + val expElogbeta = exp(LDAUtils.dirichletExpectation(lambda)).t val alpha = this.alpha.toBreeze val gammaShape = this.gammaShape - val stats: RDD[BDM[Double]] = batch.mapPartitions { docs => - val stat = BDM.zeros[Double](k, vocabSize) - docs.foreach { doc => - val termCounts = doc._2 - val (ids: List[Int], cts: Array[Double]) = termCounts match { - case v: DenseVector => ((0 until v.size).toList, v.values) - case v: SparseVector => (v.indices.toList, v.values) - case v => throw new IllegalArgumentException("Online LDA does not support vector type " - + v.getClass) - } - if (!ids.isEmpty) { - - // Initialize the variational distribution q(theta|gamma) for the mini-batch - val gammad: BDV[Double] = - new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K - val expElogthetad: BDV[Double] = exp(digamma(gammad) - digamma(sum(gammad))) // K - val expElogbetad: BDM[Double] = expElogbeta(ids, ::).toDenseMatrix // ids * K - - val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids - var meanchange = 1D - val ctsVector = new BDV[Double](cts) // ids - - // Iterate between gamma and phi until convergence - while (meanchange > 1e-3) { - val lastgamma = gammad.copy - // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha - expElogthetad := exp(digamma(gammad) - digamma(sum(gammad))) - phinorm := expElogbetad * expElogthetad :+ 1e-100 - meanchange = sum(abs(gammad - lastgamma)) / k - } + val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val nonEmptyDocs = docs.filter(_._2.numActives > 0) - stat(::, ids) := expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + val stat = BDM.zeros[Double](k, vocabSize) + var gammaPart = List[BDV[Double]]() + nonEmptyDocs.zipWithIndex.foreach { case ((_, termCounts: Vector), idx: Int) => + val ids: List[Int] = termCounts match { + case v: DenseVector => (0 until v.size).toList + case v: SparseVector => v.indices.toList } + val (gammad, sstats) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, expElogbeta, alpha, gammaShape, k) + stat(::, ids) := stat(::, ids).toDenseMatrix + sstats + gammaPart = gammad :: gammaPart } - Iterator(stat) + Iterator((stat, gammaPart)) } - - val statsSum: BDM[Double] = stats.reduce(_ += _) + val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( + stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count - update(batchResult, iteration, (miniBatchFraction * corpusSize).ceil.toInt) + updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) this } - override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { - new LocalLDAModel(Matrices.fromBreeze(lambda).transpose) - } - /** * Update lambda based on the batch submitted. batchSize can be different for each iteration. */ - private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = { + private def updateLambda(stat: BDM[Double], batchSize: Int): Unit = { // weight of the mini-batch. 
- val weight = math.pow(getTau0 + iter, -getKappa) + val weight = rho() // Update lambda based on documents. - lambda = lambda * (1 - weight) + - (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) * weight + lambda := (1 - weight) * lambda + + weight * (stat * (corpusSize.toDouble / batchSize.toDouble) + eta) + } + + /** Calculates learning rate rho, which decays as a function of [[iteration]] */ + private def rho(): Double = { + math.pow(getTau0 + this.iteration, -getKappa) } /** @@ -463,15 +448,56 @@ final class OnlineLDAOptimizer extends LDAOptimizer { new BDM[Double](col, row, temp).t } + override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { + new LocalLDAModel(Matrices.fromBreeze(lambda).transpose, alpha, eta, gammaShape) + } + +} + +/** + * Serializable companion object containing helper methods and shared code for + * [[OnlineLDAOptimizer]] and [[LocalLDAModel]]. + */ +private[clustering] object OnlineLDAOptimizer { /** - * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation - * uses digamma which is accurate but expensive. + * Uses variational inference to infer the topic distribution `gammad` given the term counts + * for a document. `termCounts` must be non-empty, otherwise Breeze will throw a BLAS error. + * + * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001) + * avoids explicit computation of variational parameter `phi`. + * @see [[http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.31.7566]] */ - private def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { - val rowSum = sum(alpha(breeze.linalg.*, ::)) - val digAlpha = digamma(alpha) - val digRowSum = digamma(rowSum) - val result = digAlpha(::, breeze.linalg.*) - digRowSum - result + private[clustering] def variationalTopicInference( + termCounts: Vector, + expElogbeta: BDM[Double], + alpha: breeze.linalg.Vector[Double], + gammaShape: Double, + k: Int): (BDV[Double], BDM[Double]) = { + val (ids: List[Int], cts: Array[Double]) = termCounts match { + case v: DenseVector => ((0 until v.size).toList, v.values) + case v: SparseVector => (v.indices.toList, v.values) + } + // Initialize the variational distribution q(theta|gamma) for the mini-batch + val gammad: BDV[Double] = + new Gamma(gammaShape, 1.0 / gammaShape).samplesVector(k) // K + val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K + val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K + + val phinorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + var meanchange = 1D + val ctsVector = new BDV[Double](cts) // ids + + // Iterate between gamma and phi until convergence + while (meanchange > 1e-3) { + val lastgamma = gammad.copy + // K K * ids ids + gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phinorm))) :+ alpha + expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) + phinorm := expElogbetad * expElogthetad :+ 1e-100 + meanchange = sum(abs(gammad - lastgamma)) / k + } + + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phinorm).asDenseMatrix + (gammad, sstatsd) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala new file mode 100644 index 0000000000000..f7e5ce1665fe6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.numerics._ + +/** + * Utility methods for LDA. + */ +object LDAUtils { + /** + * Log Sum Exp with overflow protection using the identity: + * For any a: \log \sum_{n=1}^N \exp\{x_n\} = a + \log \sum_{n=1}^N \exp\{x_n - a\} + */ + private[clustering] def logSumExp(x: BDV[Double]): Double = { + val a = max(x) + a + log(sum(exp(x :- a))) + } + + /** + * For theta ~ Dir(alpha), computes E[log(theta)] given alpha. Currently the implementation + * uses [[breeze.numerics.digamma]] which is accurate but expensive. + */ + private[clustering] def dirichletExpectation(alpha: BDV[Double]): BDV[Double] = { + digamma(alpha) - digamma(sum(alpha)) + } + + /** + * Computes [[dirichletExpectation()]] row-wise, assuming each row of alpha are + * Dirichlet parameters. + */ + private[clustering] def dirichletExpectation(alpha: BDM[Double]): BDM[Double] = { + val rowSum = sum(alpha(breeze.linalg.*, ::)) + val digAlpha = digamma(alpha) + val digRowSum = digamma(rowSum) + val result = digAlpha(::, breeze.linalg.*) - digRowSum + result + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java index b48f190f599a2..d272a42c8576f 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; import scala.Tuple2; @@ -59,7 +60,10 @@ public void tearDown() { @Test public void localLDAModel() { - LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + Matrix topics = LDASuite$.MODULE$.tinyTopics(); + double[] topicConcentration = new double[topics.numRows()]; + Arrays.fill(topicConcentration, 1.0D / topics.numRows()); + LocalLDAModel model = new LocalLDAModel(topics, Vectors.dense(topicConcentration), 1D, 100D); // Check: basic parameters assertEquals(model.k(), tinyK); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 376a87f0511b4..aa36336ebbee6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM} +import breeze.linalg.{DenseMatrix => BDM, max, argmax} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge @@ -31,7 +31,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { import LDASuite._ 
test("LocalLDAModel") { - val model = new LocalLDAModel(tinyTopics) + val model = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) // Check: basic parameters assert(model.k === tinyK) @@ -235,6 +236,51 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("LocalLDAModel logPerplexity") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(lda.log_perplexity(corpus)) + > -3.69051285096 + */ + + assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + } + test("OnlineLDAOptimizer with asymmetric prior") { def toydata: Array[(Long, Vector)] = Array( Vectors.sparse(6, Array(0, 1), Array(1, 1)), @@ -287,7 +333,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("model save/load") { // Test for LocalLDAModel. - val localModel = new LocalLDAModel(tinyTopics) + val localModel = new LocalLDAModel(tinyTopics, + Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) val tempDir1 = Utils.createTempDir() val path1 = tempDir1.toURI.toString From 86505962e6c9da1ee18c6a3533e169a22e4f1665 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 16:49:02 -0700 Subject: [PATCH 140/219] [SPARK-9448][SQL] GenerateUnsafeProjection should not share expressions across instances. We accidentally moved the list of expressions from the generated code instance to the class wrapper, and as a result, different threads are sharing the same set of expressions, which cause problems for expressions with mutable state. This pull request fixed that problem, and also added unit tests for all codegen classes, except GeneratedOrdering (which will never need any expressions since sort now only accepts bound references. Author: Reynold Xin Closes #7759 from rxin/SPARK-9448 and squashes the following commits: c09b50f [Reynold Xin] [SPARK-9448][SQL] GenerateUnsafeProjection should not share expressions across instances. 
--- .../codegen/GenerateUnsafeProjection.scala | 12 +-- .../CodegenExpressionCachingSuite.scala | 90 +++++++++++++++++++ 2 files changed, 96 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index dc725c28aaa27..7be60114ce674 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -256,18 +256,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro eval.code = createCode(ctx, eval, expressions) val code = s""" - private $exprType[] expressions; - - public Object generate($exprType[] expr) { - this.expressions = expr; - return new SpecificProjection(); + public Object generate($exprType[] exprs) { + return new SpecificProjection(exprs); } class SpecificProjection extends ${classOf[UnsafeProjection].getName} { + private $exprType[] expressions; + ${declareMutableStates(ctx)} - public SpecificProjection() { + public SpecificProjection($exprType[] expressions) { + this.expressions = expressions; ${initMutableStates(ctx)} } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala new file mode 100644 index 0000000000000..866bf904e4a4c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression} +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * A test suite that makes sure code generation handles expression internally states correctly. 
+ */ +class CodegenExpressionCachingSuite extends SparkFunSuite { + + test("GenerateUnsafeProjection") { + val expr1 = MutableExpression() + val instance1 = UnsafeProjection.create(Seq(expr1)) + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = UnsafeProjection.create(Seq(expr2)) + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GenerateProjection") { + val expr1 = MutableExpression() + val instance1 = GenerateProjection.generate(Seq(expr1)) + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GenerateProjection.generate(Seq(expr2)) + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GenerateMutableProjection") { + val expr1 = MutableExpression() + val instance1 = GenerateMutableProjection.generate(Seq(expr1))() + assert(instance1.apply(null).getBoolean(0) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GenerateMutableProjection.generate(Seq(expr2))() + assert(instance1.apply(null).getBoolean(0) === false) + assert(instance2.apply(null).getBoolean(0) === true) + } + + test("GeneratePredicate") { + val expr1 = MutableExpression() + val instance1 = GeneratePredicate.generate(expr1) + assert(instance1.apply(null) === false) + + val expr2 = MutableExpression() + expr2.mutableState = true + val instance2 = GeneratePredicate.generate(expr2) + assert(instance1.apply(null) === false) + assert(instance2.apply(null) === true) + } + +} + + +/** + * An expression with mutable state so we can change it freely in our test suite. + */ +case class MutableExpression() extends LeafExpression with CodegenFallback { + var mutableState: Boolean = false + override def eval(input: InternalRow): Any = mutableState + + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} From 103d8cce78533b38b4f8060b30f7f455113bc6b5 Mon Sep 17 00:00:00 2001 From: Bimal Tandel Date: Wed, 29 Jul 2015 16:54:58 -0700 Subject: [PATCH 141/219] [SPARK-8921] [MLLIB] Add @since tags to mllib.stat Author: Bimal Tandel Closes #7730 from BimalTandel/branch_spark_8921 and squashes the following commits: 3ea230a [Bimal Tandel] Spark 8921 add @since tags --- .../spark/mllib/stat/KernelDensity.scala | 5 ++++ .../stat/MultivariateOnlineSummarizer.scala | 27 +++++++++++++++++++ .../stat/MultivariateStatisticalSummary.scala | 9 +++++++ .../apache/spark/mllib/stat/Statistics.scala | 20 ++++++++++++-- .../distribution/MultivariateGaussian.scala | 9 +++++-- 5 files changed, 66 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala index 58a50f9c19f14..93a6753efd4d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/KernelDensity.scala @@ -37,6 +37,7 @@ import org.apache.spark.rdd.RDD * .setBandwidth(3.0) * val densities = kd.estimate(Array(-1.0, 2.0, 5.0)) * }}} + * @since 1.4.0 */ @Experimental class KernelDensity extends Serializable { @@ -51,6 +52,7 @@ class KernelDensity extends Serializable { /** * Sets the bandwidth (standard deviation) of the Gaussian kernel (default: `1.0`). 
+ * @since 1.4.0 */ def setBandwidth(bandwidth: Double): this.type = { require(bandwidth > 0, s"Bandwidth must be positive, but got $bandwidth.") @@ -60,6 +62,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation. + * @since 1.4.0 */ def setSample(sample: RDD[Double]): this.type = { this.sample = sample @@ -68,6 +71,7 @@ class KernelDensity extends Serializable { /** * Sets the sample to use for density estimation (for Java users). + * @since 1.4.0 */ def setSample(sample: JavaRDD[java.lang.Double]): this.type = { this.sample = sample.rdd.asInstanceOf[RDD[Double]] @@ -76,6 +80,7 @@ class KernelDensity extends Serializable { /** * Estimates probability density function at the given array of points. + * @since 1.4.0 */ def estimate(points: Array[Double]): Array[Double] = { val sample = this.sample diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index d321cc554c1cc..62da9f2ef22a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -33,6 +33,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * Reference: [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]] * Zero elements (including explicit zero values) are skipped when calling add(), * to have time complexity O(nnz) instead of O(n) for each column. + * @since 1.1.0 */ @DeveloperApi class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { @@ -52,6 +53,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param sample The sample in dense/sparse vector format to be added into this summarizer. * @return This MultivariateOnlineSummarizer object. + * @since 1.1.0 */ def add(sample: Vector): this.type = { if (n == 0) { @@ -107,6 +109,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S * * @param other The other MultivariateOnlineSummarizer to be merged. * @return This MultivariateOnlineSummarizer object. 
+ * @since 1.1.0 */ def merge(other: MultivariateOnlineSummarizer): this.type = { if (this.totalCnt != 0 && other.totalCnt != 0) { @@ -149,6 +152,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this } + /** + * @since 1.1.0 + */ override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -161,6 +167,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMean) } + /** + * @since 1.1.0 + */ override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -183,14 +192,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realVariance) } + /** + * @since 1.1.0 + */ override def count: Long = totalCnt + /** + * @since 1.1.0 + */ override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") Vectors.dense(nnz) } + /** + * @since 1.1.0 + */ override def max: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -202,6 +220,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMax) } + /** + * @since 1.1.0 + */ override def min: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -213,6 +234,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(currMin) } + /** + * @since 1.2.0 + */ override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") @@ -227,6 +251,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S Vectors.dense(realMagnitude) } + /** + * @since 1.2.0 + */ override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index 6a364c93284af..3bb49f12289e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -21,46 +21,55 @@ import org.apache.spark.mllib.linalg.Vector /** * Trait for multivariate statistical summary of a data matrix. + * @since 1.0.0 */ trait MultivariateStatisticalSummary { /** * Sample mean vector. + * @since 1.0.0 */ def mean: Vector /** * Sample variance vector. Should return a zero vector if the sample size is 1. + * @since 1.0.0 */ def variance: Vector /** * Sample size. + * @since 1.0.0 */ def count: Long /** * Number of nonzero elements (including explicitly presented zero values) in each column. + * @since 1.0.0 */ def numNonzeros: Vector /** * Maximum value of each column. + * @since 1.0.0 */ def max: Vector /** * Minimum value of each column. 
+ * @since 1.0.0 */ def min: Vector /** * Euclidean magnitude of each column + * @since 1.2.0 */ def normL2: Vector /** * L1 norm of each column + * @since 1.2.0 */ def normL1: Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index 90332028cfb3a..f84502919e381 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -32,6 +32,7 @@ import org.apache.spark.rdd.RDD /** * :: Experimental :: * API for statistical functions in MLlib. + * @since 1.1.0 */ @Experimental object Statistics { @@ -41,6 +42,7 @@ object Statistics { * * @param X an RDD[Vector] for which column-wise summary statistics are to be computed. * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics. + * @since 1.1.0 */ def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = { new RowMatrix(X).computeColumnSummaryStatistics() @@ -52,6 +54,7 @@ object Statistics { * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @return Pearson correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector]): Matrix = Correlations.corrMatrix(X) @@ -68,6 +71,7 @@ object Statistics { * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * @since 1.1.0 */ def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -81,10 +85,14 @@ object Statistics { * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double]): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]]) @@ -101,10 +109,14 @@ object Statistics { * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * @since 1.1.0 */ def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) - /** Java-friendly version of [[corr()]] */ + /** + * Java-friendly version of [[corr()]] + * @since 1.4.1 + */ def corr(x: JavaRDD[java.lang.Double], y: JavaRDD[java.lang.Double], method: String): Double = corr(x.rdd.asInstanceOf[RDD[Double]], y.rdd.asInstanceOf[RDD[Double]], method) @@ -121,6 +133,7 @@ object Statistics { * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { ChiSqTest.chiSquared(observed, expected) @@ -135,6 +148,7 @@ object Statistics { * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. 
+ * @since 1.1.0 */ def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) @@ -145,6 +159,7 @@ object Statistics { * @param observed The contingency matrix (containing either counts or relative frequencies). * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * @since 1.1.0 */ def chiSqTest(observed: Matrix): ChiSqTestResult = ChiSqTest.chiSquaredMatrix(observed) @@ -157,6 +172,7 @@ object Statistics { * Real-valued features will be treated as categorical for each distinct value. * @return an array containing the ChiSquaredTestResult for every feature against the label. * The order of the elements in the returned array reflects the order of input features. + * @since 1.1.0 */ def chiSqTest(data: RDD[LabeledPoint]): Array[ChiSqTestResult] = { ChiSqTest.chiSquaredFeatures(data) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index cf51b24ff777f..9aa7763d7890d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -32,6 +32,7 @@ import org.apache.spark.mllib.util.MLUtils * * @param mu The mean vector of the distribution * @param sigma The covariance matrix of the distribution + * @since 1.3.0 */ @DeveloperApi class MultivariateGaussian ( @@ -60,12 +61,16 @@ class MultivariateGaussian ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x */ + /** Returns density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x */ + /** Returns the log-density of this multivariate Gaussian at given point, x + * @since 1.3.0 + */ def logpdf(x: Vector): Double = { logpdf(x.toBreeze) } From 37c2d1927cebdd19a14c054f670cb0fb9a263586 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 29 Jul 2015 18:18:29 -0700 Subject: [PATCH 142/219] [SPARK-9016] [ML] make random forest classifiers implement classification trait Implement the classification trait for RandomForestClassifiers. The plan is to use this in the future to providing thresholding for RandomForestClassifiers (as well as other classifiers that implement that trait). 
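The idea is easy to see in isolation: each component tree casts one equally-weighted vote for a class, the accumulated votes form the raw prediction vector, and its argmax is the predicted class. The following standalone Scala sketch (hypothetical names; trees stubbed as plain functions from features to a class index) illustrates this, not the actual `predictRaw` implementation.

    // Sketch only: how per-tree votes become a raw prediction vector.
    object VoteCountingSketch extends App {
      def rawPrediction(trees: Seq[Array[Double] => Int],
                        features: Array[Double],
                        numClasses: Int): Array[Double] = {
        val votes = new Array[Double](numClasses)
        trees.foreach(tree => votes(tree(features)) += 1.0) // each tree has weight 1.0
        votes
      }

      val trees: Seq[Array[Double] => Int] = Seq(_ => 0, _ => 1, _ => 1)
      val raw = rawPrediction(trees, Array(0.5), numClasses = 2)
      println(raw.mkString(", "))              // 1.0, 2.0
      println(raw.zipWithIndex.maxBy(_._1)._2) // predicted class: 1
    }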
Author: Holden Karau Closes #7432 from holdenk/SPARK-9016-make-random-forest-classifiers-implement-classification-trait and squashes the following commits: bf22fa6 [Holden Karau] Add missing imports for testing suite e948f0d [Holden Karau] Check the prediction generation from rawprediciton 25320c3 [Holden Karau] Don't supply numClasses when not needed, assert model classes are as expected 1a67e04 [Holden Karau] Use old decission tree stuff instead 673e0c3 [Holden Karau] Merge branch 'master' into SPARK-9016-make-random-forest-classifiers-implement-classification-trait 0d15b96 [Holden Karau] FIx typo 5eafad4 [Holden Karau] add a constructor for rootnode + num classes fc6156f [Holden Karau] scala style fix 2597915 [Holden Karau] take num classes in constructor 3ccfe4a [Holden Karau] Merge in master, make pass numClasses through randomforest for training 222a10b [Holden Karau] Increase numtrees to 3 in the python test since before the two were equal and the argmax was selecting the last one 16aea1c [Holden Karau] Make tests match the new models b454a02 [Holden Karau] Make the Tree classifiers extends the Classifier base class 77b4114 [Holden Karau] Import vectors lib --- .../RandomForestClassifier.scala | 30 ++++++++++--------- .../RandomForestClassifierSuite.scala | 18 ++++++++--- python/pyspark/ml/classification.py | 4 +-- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index fc0693f67cc2e..bc19bd6df894f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} @@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType */ @Experimental final class RandomForestClassifier(override val uid: String) - extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel] + extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel] with RandomForestParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("rfc")) @@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String) val trees = RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed) .map(_.asInstanceOf[DecisionTreeClassificationModel]) - new RandomForestClassificationModel(trees) + new RandomForestClassificationModel(trees, numClasses) } override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra) @@ -125,8 +125,9 @@ object RandomForestClassifier { @Experimental final class RandomForestClassificationModel private[ml] ( override val uid: String, - private val _trees: Array[DecisionTreeClassificationModel]) - extends PredictionModel[Vector, RandomForestClassificationModel] + private val _trees: 
Array[DecisionTreeClassificationModel], + override val numClasses: Int) + extends ClassificationModel[Vector, RandomForestClassificationModel] with TreeEnsembleModel with Serializable { require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") @@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] ( * Construct a random forest classification model, with all trees weighted equally. * @param trees Component trees */ - def this(trees: Array[DecisionTreeClassificationModel]) = - this(Identifiable.randomUID("rfc"), trees) + def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numClasses) override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] @@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] ( dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } - override protected def predict(features: Vector): Double = { + override protected def predictRaw(features: Vector): Vector = { // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 // Classifies using majority votes. // Ignore the weights since all are 1.0 for now. - val votes = mutable.Map.empty[Int, Double] + val votes = new Array[Double](numClasses) _trees.view.foreach { tree => val prediction = tree.rootNode.predict(features).toInt - votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight + votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } - votes.maxBy(_._2)._1 + Vectors.dense(votes) } override def copy(extra: ParamMap): RandomForestClassificationModel = { - copyValues(new RandomForestClassificationModel(uid, _trees), extra) + copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra) } override def toString: String = { @@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel { def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, - categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = { + categoricalFeatures: Map[Int, Int], + numClasses: Int): RandomForestClassificationModel = { require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") val newTrees = oldModel.trees.map { tree => @@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel { DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") - new RandomForestClassificationModel(uid, newTrees) + new RandomForestClassificationModel(uid, newTrees, numClasses) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 1b6b69c7dc71e..ab711c8e4b215 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import 
org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)))) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) ParamsSuite.checkParams(model) } @@ -167,9 +167,19 @@ private object RandomForestClassifierSuite { val newModel = rf.fit(newData) // Use parent from newTree since this is not checked anyways. val oldModelAsNew = RandomForestClassificationModel.fromOld( - oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) + oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures, + numClasses) TreeTests.checkEqual(oldModelAsNew, newModel) assert(newModel.hasParent) assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) + assert(newModel.numClasses == numClasses) + val results = newModel.transform(newData) + results.select("rawPrediction", "prediction").collect().foreach { + case Row(raw: Vector, prediction: Double) => { + assert(raw.size == numClasses) + val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2 + assert(predFromRaw == prediction) + } + } } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 89117e492846b..5a82bc286d1e8 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) >>> td = si_model.transform(df) - >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) + >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) - >>> allclose(model.treeWeights, [1.0, 1.0]) + >>> allclose(model.treeWeights, [1.0, 1.0, 1.0]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction From 2a9fe4a4e7acbe4c9d3b6c6e61ff46d1472ee5f4 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 29 Jul 2015 18:23:07 -0700 Subject: [PATCH 143/219] [SPARK-6129] [MLLIB] [DOCS] Added user guide for evaluation metrics Author: sethah Closes #7655 from sethah/Working_on_6129 and squashes the following commits: 253db2d [sethah] removed number formatting from example code b769cab [sethah] rewording threshold section d5dad4d [sethah] adding some explanations of concepts to the eval metrics user guide 3a61ff9 [sethah] Removing unnecessary latex commands from metrics guide c9dd058 [sethah] Cleaning up and formatting metrics user guide section 6f31c21 [sethah] All example code for metrics section done 98813fe [sethah] Most java and python example code added. 
Further latex formatting
53a24fc [sethah] Adding documentations of metrics for ML algorithms to user guide
---
 docs/mllib-evaluation-metrics.md | 1497 ++++++++++++++++++++++++++++++
 docs/mllib-guide.md | 1 +
 2 files changed, 1498 insertions(+)
 create mode 100644 docs/mllib-evaluation-metrics.md

diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
new file mode 100644
index 0000000000000..4ca0bb06b26a6
--- /dev/null
+++ b/docs/mllib-evaluation-metrics.md
@@ -0,0 +1,1497 @@
+---
+layout: global
+title: Evaluation Metrics - MLlib
+displayTitle: MLlib - Evaluation Metrics
+---
+
+* Table of contents
+{:toc}
+
+Spark's MLlib comes with a number of machine learning algorithms that can be used to learn from and make predictions
+on data. When these algorithms are applied to build machine learning models, there is a need to evaluate the performance
+of the model on some criteria, which depends on the application and its requirements. Spark's MLlib also provides a
+suite of metrics for the purpose of evaluating the performance of machine learning models.
+
+Specific machine learning algorithms fall under broader types of machine learning applications like classification,
+regression, clustering, etc. Each of these types has well-established metrics for performance evaluation, and the
+metrics currently available in Spark's MLlib are detailed in this section.
+
+## Classification model evaluation
+
+While there are many different types of classification algorithms, the evaluation of classification models shares
+similar principles. In a [supervised classification problem](https://en.wikipedia.org/wiki/Statistical_classification),
+there exists a true output and a model-generated predicted output for each data point. For this reason, the results for
+each data point can be assigned to one of four categories:
+
+* True Positive (TP) - label is positive and prediction is also positive
+* True Negative (TN) - label is negative and prediction is also negative
+* False Positive (FP) - label is negative but prediction is positive
+* False Negative (FN) - label is positive but prediction is negative
+
+These four numbers are the building blocks for most classifier evaluation metrics. A fundamental point when considering
+classifier evaluation is that pure accuracy (i.e. was the prediction correct or incorrect) is not generally a good metric.
+The reason is that a dataset may be highly unbalanced. For example, if a model is designed to predict fraud from
+a dataset where 95% of the data points are _not fraud_ and 5% of the data points are _fraud_, then a naive classifier
+that predicts _not fraud_, regardless of input, will be 95% accurate. For this reason, metrics like
+[precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) are typically used because they take into
+account the *type* of error. In most applications there is some desired balance between precision and recall, which can
+be captured by combining the two into a single metric, called the [F-measure](https://en.wikipedia.org/wiki/F1_score).
+
+### Binary classification
+
+[Binary classifiers](https://en.wikipedia.org/wiki/Binary_classification) are used to separate the elements of a given
+dataset into one of two possible groups (e.g. fraud or not fraud); binary classification is a special case of
+multiclass classification. Most binary classification metrics can be generalized to multiclass classification metrics.
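As a concrete illustration of these four counts (a standalone Scala sketch with assumed numbers, not part of this patch), precision, recall, and the F-measure can be computed directly:

    // Precision, recall, and F-beta from TP/FP/FN counts.
    object BinaryMetricsByHand extends App {
      def precision(tp: Double, fp: Double): Double = tp / (tp + fp)
      def recall(tp: Double, fn: Double): Double = tp / (tp + fn)
      def fMeasure(p: Double, r: Double, beta: Double): Double =
        (1 + beta * beta) * p * r / (beta * beta * p + r)

      // Assumed counts: out of 100 points (95 not fraud, 5 fraud), a classifier
      // flags all 5 fraud cases plus 5 false alarms: TP = 5, FP = 5, FN = 0.
      println(precision(5.0, 5.0))      // 0.5
      println(recall(5.0, 0.0))         // 1.0
      println(fMeasure(0.5, 1.0, 1.0))  // ~0.667, balancing precision and recall
    }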
+
+#### Threshold tuning
+
+It is important to understand that many classification models actually output a "score" (often a probability) for
+each class, where a higher score indicates higher likelihood. In the binary case, the model may output a probability for
+each class: $P(Y=1|X)$ and $P(Y=0|X)$. Instead of simply taking the higher probability, there may be some cases where
+the model might need to be tuned so that it only predicts a class when the probability is very high (e.g. only block a
+credit card transaction if the model predicts fraud with >90% probability). Therefore, there is a prediction *threshold*
+which determines what the predicted class will be based on the probabilities that the model outputs.
+
+Tuning the prediction threshold will change the precision and recall of the model and is an important part of model
+optimization. In order to visualize how precision, recall, and other metrics change as a function of the threshold, it is
+common practice to plot competing metrics against one another, parameterized by threshold. A P-R curve plots (precision,
+recall) points for different threshold values, while a
+[receiver operating characteristic](https://en.wikipedia.org/wiki/Receiver_operating_characteristic), or ROC, curve
+plots (recall, false positive rate) points.
+
+**Available metrics**
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision (Positive Predictive Value)</td>
+      <td>$PPV=\frac{TP}{TP + FP}$</td>
+    </tr>
+    <tr>
+      <td>Recall (True Positive Rate)</td>
+      <td>$TPR=\frac{TP}{P}=\frac{TP}{TP + FN}$</td>
+    </tr>
+    <tr>
+      <td>F-measure</td>
+      <td>$F(\beta) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV \cdot TPR}{\beta^2 \cdot PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Receiver Operating Characteristic (ROC)</td>
+      <td>$FPR(T)=\int^\infty_{T} P_0(T)\,dT \\ TPR(T)=\int^\infty_{T} P_1(T)\,dT$</td>
+    </tr>
+    <tr>
+      <td>Area Under ROC Curve</td>
+      <td>$AUROC=\int^1_{0} \frac{TP}{P} d\left(\frac{FP}{N}\right)$</td>
+    </tr>
+    <tr>
+      <td>Area Under Precision-Recall Curve</td>
+      <td>$AUPRC=\int^1_{0} \frac{TP}{TP+FP} d\left(\frac{TP}{P}\right)$</td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train a binary classification algorithm on the +data, and evaluate the performance of the algorithm by several binary evaluation metrics. + +
+
+{% highlight scala %}
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load training data in LIBSVM format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt")
+
+// Split data into training (60%) and test (40%)
+val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L)
+training.cache()
+
+// Run training algorithm to build the model
+val model = new LogisticRegressionWithLBFGS()
+  .setNumClasses(2)
+  .run(training)
+
+// Clear the prediction threshold so the model will return probabilities
+model.clearThreshold
+
+// Compute raw scores on the test set
+val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
+  val prediction = model.predict(features)
+  (prediction, label)
+}
+
+// Instantiate metrics object
+val metrics = new BinaryClassificationMetrics(predictionAndLabels)
+
+// Precision by threshold
+val precision = metrics.precisionByThreshold
+precision.foreach { case (t, p) =>
+  println(s"Threshold: $t, Precision: $p")
+}
+
+// Recall by threshold
+val recall = metrics.recallByThreshold
+recall.foreach { case (t, r) =>
+  println(s"Threshold: $t, Recall: $r")
+}
+
+// Precision-Recall Curve
+val PRC = metrics.pr
+
+// F-measure
+val f1Score = metrics.fMeasureByThreshold
+f1Score.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 1")
+}
+
+val beta = 0.5
+val fScore = metrics.fMeasureByThreshold(beta)
+fScore.foreach { case (t, f) =>
+  println(s"Threshold: $t, F-score: $f, Beta = 0.5")
+}
+
+// AUPRC
+val auPRC = metrics.areaUnderPR
+println("Area under precision-recall curve = " + auPRC)
+
+// Compute thresholds used in ROC and PR curves
+val thresholds = precision.map(_._1)
+
+// ROC Curve
+val roc = metrics.roc
+
+// AUROC
+val auROC = metrics.areaUnderROC
+println("Area under ROC = " + auROC)
+
+{% endhighlight %}
+
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class BinaryClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Binary Classification Metrics");
+    SparkContext sc = new SparkContext(conf);
+    String path = "data/mllib/sample_binary_classification_data.txt";
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+    // Split initial RDD into two... [60% training data, 40% testing data].
+    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+    JavaRDD<LabeledPoint> training = splits[0].cache();
+    JavaRDD<LabeledPoint> test = splits[1];
+
+    // Run training algorithm to build the model.
+    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(2)
+      .run(training.rdd());
+
+    // Clear the prediction threshold so the model will return probabilities
+    model.clearThreshold();
+
+    // Compute raw scores on the test set.
+    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint p) {
+          Double prediction = model.predict(p.features());
+          return new Tuple2<Object, Object>(prediction, p.label());
+        }
+      }
+    );
+
+    // Get evaluation metrics.
+    BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictionAndLabels.rdd());
+
+    // Precision by threshold
+    JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
+    System.out.println("Precision by threshold: " + precision.toArray());
+
+    // Recall by threshold
+    JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
+    System.out.println("Recall by threshold: " + recall.toArray());
+
+    // F Score by threshold
+    JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
+    System.out.println("F1 Score by threshold: " + f1Score.toArray());
+
+    JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
+    System.out.println("F2 Score by threshold: " + f2Score.toArray());
+
+    // Precision-recall curve
+    JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
+    System.out.println("Precision-recall curve: " + prc.toArray());
+
+    // Thresholds
+    JavaRDD<Double> thresholds = precision.map(
+      new Function<Tuple2<Object, Object>, Double>() {
+        public Double call (Tuple2<Object, Object> t) {
+          return new Double(t._1().toString());
+        }
+      }
+    );
+
+    // ROC Curve
+    JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
+    System.out.println("ROC curve: " + roc.toArray());
+
+    // AUPRC
+    System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());
+
+    // AUROC
+    System.out.println("Area under ROC = " + metrics.areaUnderROC());
+
+    // Save and load model
+    model.save(sc, "myModelPath");
+    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.evaluation import BinaryClassificationMetrics +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import MLUtils + +# Several of the methods available in scala are currently missing from pyspark + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_binary_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = BinaryClassificationMetrics(predictionAndLabels) + +# Area under precision-recall curve +print "Area under PR = %s" % metrics.areaUnderPR + +# Area under ROC curve +print "Area under ROC = %s" % metrics.areaUnderROC + +{% endhighlight %} + +
+
+
+
+### Multiclass classification
+
+A [multiclass classification](https://en.wikipedia.org/wiki/Multiclass_classification) describes a classification
+problem where there are $M \gt 2$ possible labels for each data point (the case where $M=2$ is the binary
+classification problem). For example, classifying handwriting samples as the digits 0 to 9 has 10 possible classes.
+
+For multiclass metrics, the notion of positives and negatives is slightly different. Predictions and labels can still
+be positive or negative, but they must be considered under the context of a particular class. Each label and prediction
+take on the value of one of the multiple classes and so they are said to be positive for their particular class and negative
+for all other classes. So, a true positive occurs whenever the prediction and the label match, while a true negative
+occurs when neither the prediction nor the label take on the value of a given class. By this convention, there can be
+multiple true negatives for a given data sample. The extension of false negatives and false positives from the former
+definitions of positive and negative labels is straightforward.
+
+#### Label based metrics
+
+As opposed to binary classification, where there are only two possible labels, multiclass classification problems have
+many possible labels, and so the concept of label-based metrics is introduced. Overall precision measures precision
+across all labels - the number of times any class was predicted correctly (true positives) normalized by the number of
+data points. Precision by label considers only one class, and measures the number of times a specific label was
+predicted correctly normalized by the number of times that label appears in the output.
+
+**Available metrics**
+
+Define the class, or label, set as
+
+$$L = \{\ell_0, \ell_1, \ldots, \ell_{M-1} \} $$
+
+The true output vector $\mathbf{y}$ consists of $N$ elements
+
+$$\mathbf{y}_0, \mathbf{y}_1, \ldots, \mathbf{y}_{N-1} \in L $$
+
+A multiclass prediction algorithm generates a prediction vector $\hat{\mathbf{y}}$ of $N$ elements
+
+$$\hat{\mathbf{y}}_0, \hat{\mathbf{y}}_1, \ldots, \hat{\mathbf{y}}_{N-1} \in L $$
+
+For this section, a modified delta function $\hat{\delta}(x)$ will prove useful
+
+$$\hat{\delta}(x) = \begin{cases}1 & \text{if $x = 0$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Confusion Matrix</td>
+      <td>
+        $C_{ij} = \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_i) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_j)\\ \\
+        \left( \begin{array}{ccc}
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots &
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_0) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1}) \\
+          \vdots & \ddots & \vdots \\
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_0) & \ldots &
+          \sum_{k=0}^{N-1} \hat{\delta}(\mathbf{y}_k-\ell_{M-1}) \cdot \hat{\delta}(\hat{\mathbf{y}}_k - \ell_{M-1})
+        \end{array} \right)$
+      </td>
+    </tr>
+    <tr>
+      <td>Overall Precision</td>
+      <td>$PPV = \frac{TP}{TP + FP} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall Recall</td>
+      <td>$TPR = \frac{TP}{TP + FN} = \frac{1}{N}\sum_{i=0}^{N-1} \hat{\delta}\left(\hat{\mathbf{y}}_i - \mathbf{y}_i\right)$</td>
+    </tr>
+    <tr>
+      <td>Overall F1-measure</td>
+      <td>$F1 = 2 \cdot \left(\frac{PPV \cdot TPR}{PPV + TPR}\right)$</td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell) = \frac{TP}{TP + FP} =
+        \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
+        {\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P} =
+        \frac{\sum_{i=0}^{N-1} \hat{\delta}(\hat{\mathbf{y}}_i - \ell) \cdot \hat{\delta}(\mathbf{y}_i - \ell)}
+        {\sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i - \ell)}$</td>
+    </tr>
+    <tr>
+      <td>F-measure by label</td>
+      <td>$F(\beta, \ell) = \left(1 + \beta^2\right) \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
+        {\beta^2 \cdot PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Weighted precision</td>
+      <td>$PPV_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} PPV(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted recall</td>
+      <td>$TPR_{w}= \frac{1}{N} \sum\nolimits_{\ell \in L} TPR(\ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+    <tr>
+      <td>Weighted F-measure</td>
+      <td>$F_{w}(\beta)= \frac{1}{N} \sum\nolimits_{\ell \in L} F(\beta, \ell) \cdot \sum_{i=0}^{N-1} \hat{\delta}(\mathbf{y}_i-\ell)$</td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train a multiclass classification algorithm on +the data, and evaluate the performance of the algorithm by several multiclass classification evaluation metrics. + +
+ +{% highlight scala %} +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// Load training data in LIBSVM format +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +// Split data into training (60%) and test (40%) +val Array(training, test) = data.randomSplit(Array(0.6, 0.4), seed = 11L) +training.cache() + +// Run training algorithm to build the model +val model = new LogisticRegressionWithLBFGS() + .setNumClasses(3) + .run(training) + +// Compute raw scores on the test set +val predictionAndLabels = test.map { case LabeledPoint(label, features) => + val prediction = model.predict(features) + (prediction, label) +} + +// Instantiate metrics object +val metrics = new MulticlassMetrics(predictionAndLabels) + +// Confusion matrix +println("Confusion matrix:") +println(metrics.confusionMatrix) + +// Overall Statistics +val precision = metrics.precision +val recall = metrics.recall // same as true positive rate +val f1Score = metrics.fMeasure +println("Summary Statistics") +println(s"Precision = $precision") +println(s"Recall = $recall") +println(s"F1 Score = $f1Score") + +// Precision by label +val labels = metrics.labels +labels.foreach { l => + println(s"Precision($l) = " + metrics.precision(l)) +} + +// Recall by label +labels.foreach { l => + println(s"Recall($l) = " + metrics.recall(l)) +} + +// False positive rate by label +labels.foreach { l => + println(s"FPR($l) = " + metrics.falsePositiveRate(l)) +} + +// F-measure by label +labels.foreach { l => + println(s"F1-Score($l) = " + metrics.fMeasure(l)) +} + +// Weighted stats +println(s"Weighted precision: ${metrics.weightedPrecision}") +println(s"Weighted recall: ${metrics.weightedRecall}") +println(s"Weighted F1 score: ${metrics.weightedFMeasure}") +println(s"Weighted false positive rate: ${metrics.weightedFalsePositiveRate}") + +{% endhighlight %} + +
+ +
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.classification.LogisticRegressionModel;
+import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
+import org.apache.spark.mllib.evaluation.MulticlassMetrics;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+
+public class MulticlassClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multiclass Classification Metrics");
+    SparkContext sc = new SparkContext(conf);
+    String path = "data/mllib/sample_multiclass_classification_data.txt";
+    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();
+
+    // Split initial RDD into two... [60% training data, 40% testing data].
+    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] {0.6, 0.4}, 11L);
+    JavaRDD<LabeledPoint> training = splits[0].cache();
+    JavaRDD<LabeledPoint> test = splits[1];
+
+    // Run training algorithm to build the model.
+    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+      .setNumClasses(3)
+      .run(training.rdd());
+
+    // Compute raw scores on the test set.
+    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint p) {
+          Double prediction = model.predict(p.features());
+          return new Tuple2<Object, Object>(prediction, p.label());
+        }
+      }
+    );
+
+    // Get evaluation metrics.
+    MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
+
+    // Confusion matrix
+    Matrix confusion = metrics.confusionMatrix();
+    System.out.println("Confusion matrix: \n" + confusion);
+
+    // Overall statistics
+    System.out.println("Precision = " + metrics.precision());
+    System.out.println("Recall = " + metrics.recall());
+    System.out.println("F1 Score = " + metrics.fMeasure());
+
+    // Stats by labels
+    for (int i = 0; i < metrics.labels().length; i++) {
+      System.out.format("Class %f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+      System.out.format("Class %f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+      System.out.format("Class %f F1 score = %f\n", metrics.labels()[i], metrics.fMeasure(metrics.labels()[i]));
+    }
+
+    //Weighted stats
+    System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
+    System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
+    System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
+    System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());
+
+    // Save and load model
+    model.save(sc, "myModelPath");
+    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, "myModelPath");
+  }
+}
+
+{% endhighlight %}
+
+ +
+ +{% highlight python %} +from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.util import MLUtils +from pyspark.mllib.evaluation import MulticlassMetrics + +# Load training data in LIBSVM format +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_multiclass_classification_data.txt") + +# Split data into training (60%) and test (40%) +training, test = data.randomSplit([0.6, 0.4], seed = 11L) +training.cache() + +# Run training algorithm to build the model +model = LogisticRegressionWithLBFGS.train(training, numClasses=3) + +# Compute raw scores on the test set +predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp.label)) + +# Instantiate metrics object +metrics = MulticlassMetrics(predictionAndLabels) + +# Overall statistics +precision = metrics.precision() +recall = metrics.recall() +f1Score = metrics.fMeasure() +print "Summary Stats" +print "Precision = %s" % precision +print "Recall = %s" % recall +print "F1 Score = %s" % f1Score + +# Statistics by class +labels = data.map(lambda lp: lp.label).distinct().collect() +for label in sorted(labels): + print "Class %s precision = %s" % (label, metrics.precision(label)) + print "Class %s recall = %s" % (label, metrics.recall(label)) + print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + +# Weighted stats +print "Weighted recall = %s" % metrics.weightedRecall +print "Weighted precision = %s" % metrics.weightedPrecision +print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() +print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) +print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +{% endhighlight %} + +
+</div>
+
+</div>
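+
+For intuition, the snippet below re-derives per-class precision, recall, and F1 from raw
+(prediction, label) pairs in plain Python, without Spark. The pairs and the helper functions are
+made up for this illustration and are not part of the MLlib API; they only mirror what
+`MulticlassMetrics` reports per class.
+
+{% highlight python %}
+# Hand-rolled versions of the per-class statistics; the (prediction, label)
+# pairs below are invented for this example.
+pairs = [(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (2.0, 2.0), (1.0, 2.0), (0.0, 0.0)]
+
+def precision(c):
+    # Of all points predicted as class c, the fraction actually labeled c.
+    predicted_c = [(p, l) for (p, l) in pairs if p == c]
+    if not predicted_c:
+        return 0.0
+    return sum(1 for (p, l) in predicted_c if l == c) / float(len(predicted_c))
+
+def recall(c):
+    # Of all points truly labeled c, the fraction predicted as c.
+    actual_c = [(p, l) for (p, l) in pairs if l == c]
+    return sum(1 for (p, l) in actual_c if p == c) / float(len(actual_c))
+
+def f1(c):
+    p, r = precision(c), recall(c)
+    return 0.0 if p + r == 0 else 2 * p * r / (p + r)
+
+# Class 0.0 is predicted three times and correct twice -> precision 2/3;
+# both true class-0.0 points are found -> recall 1.0, so F1 = 0.8.
+assert abs(precision(0.0) - 2.0 / 3.0) < 1e-9
+assert abs(recall(0.0) - 1.0) < 1e-9
+assert abs(f1(0.0) - 0.8) < 1e-9
+{% endhighlight %}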
+
+### Multilabel classification
+
+A [multilabel classification](https://en.wikipedia.org/wiki/Multi-label_classification) problem involves mapping
+each sample in a dataset to a set of class labels. In this type of classification problem, the labels are not
+mutually exclusive. For example, when classifying a set of news articles into topics, a single article might be both
+science and politics.
+
+Because the labels are not mutually exclusive, the predictions and true labels are now vectors of label *sets*, rather
+than vectors of labels. Multilabel metrics, therefore, extend the fundamental ideas of precision, recall, etc. to
+operations on sets. For example, a true positive for a given class now occurs when, for a specific data point, that
+class appears in both the predicted set and the true label set.
+
+**Available metrics**
+
+Here we define a set $D$ of $N$ documents
+
+$$D = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+Define $L_0, L_1, ..., L_{N-1}$ to be a family of label sets and $P_0, P_1, ..., P_{N-1}$
+to be a family of prediction sets where $L_i$ and $P_i$ are the label set and prediction set, respectively, that
+correspond to document $d_i$.
+
+The set of all unique labels is given by
+
+$$L = \bigcup_{k=0}^{N-1} L_k$$
+
+We will also need the following definition of an indicator function $I_A(x)$ on a set $A$
+
+$$I_A(x) = \begin{cases}1 & \text{if $x \in A$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|P_i \cap L_i\right|}{\left|P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Recall</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} \frac{\left|L_i \cap P_i\right|}{\left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Accuracy</td>
+      <td>
+        $\frac{1}{N} \sum_{i=0}^{N - 1} \frac{\left|L_i \cap P_i \right|}
+        {\left|L_i\right| + \left|P_i\right| - \left|L_i \cap P_i \right|}$
+      </td>
+    </tr>
+    <tr>
+      <td>Precision by label</td>
+      <td>$PPV(\ell)=\frac{TP}{TP + FP}=
+          \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}
+          {\sum_{i=0}^{N-1} I_{P_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>Recall by label</td>
+      <td>$TPR(\ell)=\frac{TP}{P}=
+          \frac{\sum_{i=0}^{N-1} I_{P_i}(\ell) \cdot I_{L_i}(\ell)}
+          {\sum_{i=0}^{N-1} I_{L_i}(\ell)}$</td>
+    </tr>
+    <tr>
+      <td>F1-measure by label</td>
+      <td>$F1(\ell) = 2
+          \cdot \left(\frac{PPV(\ell) \cdot TPR(\ell)}
+          {PPV(\ell) + TPR(\ell)}\right)$</td>
+    </tr>
+    <tr>
+      <td>Hamming Loss</td>
+      <td>
+        $\frac{1}{N \cdot \left|L\right|} \sum_{i=0}^{N - 1} \left|L_i\right| + \left|P_i\right| - 2\left|L_i
+        \cap P_i\right|$
+      </td>
+    </tr>
+    <tr>
+      <td>Subset Accuracy</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} I_{\{L_i\}}(P_i)$</td>
+    </tr>
+    <tr>
+      <td>F1 Measure</td>
+      <td>$\frac{1}{N} \sum_{i=0}^{N-1} 2 \frac{\left|P_i \cap L_i\right|}{\left|P_i\right| + \left|L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro precision</td>
+      <td>$\frac{TP}{TP + FP}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
+          {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|P_i - L_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro recall</td>
+      <td>$\frac{TP}{TP + FN}=\frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}
+          {\sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right|}$</td>
+    </tr>
+    <tr>
+      <td>Micro F1 Measure</td>
+      <td>
+        $2 \cdot \frac{TP}{2 \cdot TP + FP + FN}=2 \cdot \frac{\sum_{i=0}^{N-1} \left|P_i \cap L_i\right|}{2 \cdot
+        \sum_{i=0}^{N-1} \left|P_i \cap L_i\right| + \sum_{i=0}^{N-1} \left|L_i - P_i\right| + \sum_{i=0}^{N-1}
+        \left|P_i - L_i\right|}$
+      </td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to evaluate the performance of a multilabel classifier. The examples
+use the fake prediction and label data for multilabel classification that is shown below.
+
+Document predictions:
+
+* doc 0 - predict 0, 1 - class 0, 2
+* doc 1 - predict 0, 2 - class 0, 1
+* doc 2 - predict none - class 0
+* doc 3 - predict 2 - class 2
+* doc 4 - predict 2, 0 - class 2, 0
+* doc 5 - predict 0, 1, 2 - class 0, 1
+* doc 6 - predict 1 - class 1, 2
+
+Predicted classes:
+
+* class 0 - doc 0, 1, 4, 5 (total 4)
+* class 1 - doc 0, 5, 6 (total 3)
+* class 2 - doc 1, 3, 4, 5 (total 4)
+
+True classes:
+
+* class 0 - doc 0, 1, 2, 4, 5 (total 5)
+* class 1 - doc 1, 5, 6 (total 3)
+* class 2 - doc 0, 3, 4, 6 (total 4)
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.mllib.evaluation.MultilabelMetrics
+import org.apache.spark.rdd.RDD
+
+val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+  Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+    (Array(0.0, 2.0), Array(0.0, 1.0)),
+    (Array(), Array(0.0)),
+    (Array(2.0), Array(2.0)),
+    (Array(2.0, 0.0), Array(2.0, 0.0)),
+    (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+    (Array(1.0), Array(1.0, 2.0))), 2)
+
+// Instantiate metrics object
+val metrics = new MultilabelMetrics(scoreAndLabels)
+
+// Summary stats
+println(s"Recall = ${metrics.recall}")
+println(s"Precision = ${metrics.precision}")
+println(s"F1 measure = ${metrics.f1Measure}")
+println(s"Accuracy = ${metrics.accuracy}")
+
+// Individual label stats
+metrics.labels.foreach(label => println(s"Class $label precision = ${metrics.precision(label)}"))
+metrics.labels.foreach(label => println(s"Class $label recall = ${metrics.recall(label)}"))
+metrics.labels.foreach(label => println(s"Class $label F1-score = ${metrics.f1Measure(label)}"))
+
+// Micro stats
+println(s"Micro recall = ${metrics.microRecall}")
+println(s"Micro precision = ${metrics.microPrecision}")
+println(s"Micro F1 measure = ${metrics.microF1Measure}")
+
+// Hamming loss
+println(s"Hamming loss = ${metrics.hammingLoss}")
+
+// Subset accuracy
+println(s"Subset accuracy = ${metrics.subsetAccuracy}")
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.evaluation.MultilabelMetrics;
+import org.apache.spark.SparkConf;
+import java.util.Arrays;
+import java.util.List;
+
+public class MultilabelClassification {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Multilabel Classification Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    List<Tuple2<double[], double[]>> data = Arrays.asList(
+      new Tuple2<double[], double[]>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}),
+      new Tuple2<double[], double[]>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}),
+      new Tuple2<double[], double[]>(new double[]{}, new double[]{0.0}),
+      new Tuple2<double[], double[]>(new double[]{2.0}, new double[]{2.0}),
+      new Tuple2<double[], double[]>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}),
+      new Tuple2<double[], double[]>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}),
+      new Tuple2<double[], double[]>(new double[]{1.0}, new double[]{1.0, 2.0})
+    );
+    JavaRDD<Tuple2<double[], double[]>> scoreAndLabels = sc.parallelize(data);
+
+    // Instantiate metrics object
+    MultilabelMetrics metrics = new MultilabelMetrics(scoreAndLabels.rdd());
+
+    // Summary stats
+    System.out.format("Recall = %f\n", metrics.recall());
+    System.out.format("Precision = %f\n", metrics.precision());
+    System.out.format("F1 measure = %f\n", metrics.f1Measure());
+    System.out.format("Accuracy = %f\n", metrics.accuracy());
+
+    // Stats by labels
+    for (int i = 0; i < metrics.labels().length; i++) {
+      System.out.format("Class %1.1f precision = %f\n", metrics.labels()[i], metrics.precision(metrics.labels()[i]));
+      System.out.format("Class %1.1f recall = %f\n", metrics.labels()[i], metrics.recall(metrics.labels()[i]));
+      System.out.format("Class %1.1f F1 score = %f\n", metrics.labels()[i], metrics.f1Measure(metrics.labels()[i]));
+    }
+
+    // Micro stats
+    System.out.format("Micro recall = %f\n", metrics.microRecall());
+    System.out.format("Micro precision = %f\n", metrics.microPrecision());
+    System.out.format("Micro F1 measure = %f\n", metrics.microF1Measure());
+
+    // Hamming loss
+    System.out.format("Hamming loss = %f\n", metrics.hammingLoss());
+
+    // Subset accuracy
+    System.out.format("Subset accuracy = %f\n", metrics.subsetAccuracy());
+  }
+}
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+from pyspark.mllib.evaluation import MultilabelMetrics
+
+scoreAndLabels = sc.parallelize([
+    ([0.0, 1.0], [0.0, 2.0]),
+    ([0.0, 2.0], [0.0, 1.0]),
+    ([], [0.0]),
+    ([2.0], [2.0]),
+    ([2.0, 0.0], [2.0, 0.0]),
+    ([0.0, 1.0, 2.0], [0.0, 1.0]),
+    ([1.0], [1.0, 2.0])])
+
+# Instantiate metrics object
+metrics = MultilabelMetrics(scoreAndLabels)
+
+# Summary stats
+print "Recall = %s" % metrics.recall()
+print "Precision = %s" % metrics.precision()
+print "F1 measure = %s" % metrics.f1Measure()
+print "Accuracy = %s" % metrics.accuracy
+
+# Individual label stats
+labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect()
+for label in labels:
+    print "Class %s precision = %s" % (label, metrics.precision(label))
+    print "Class %s recall = %s" % (label, metrics.recall(label))
+    print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))
+
+# Micro stats
+print "Micro precision = %s" % metrics.microPrecision
+print "Micro recall = %s" % metrics.microRecall
+print "Micro F1 measure = %s" % metrics.microF1Measure
+
+# Hamming loss
+print "Hamming loss = %s" % metrics.hammingLoss
+
+# Subset accuracy
+print "Subset accuracy = %s" % metrics.subsetAccuracy
+{% endhighlight %}
+
+</div>
+
+</div>
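+
+To make the set-based definitions concrete, the snippet below recomputes a few of the summary
+statistics by hand in plain Python on the same seven-document toy data. It is only an illustrative
+cross-check of the formulas, not the `MultilabelMetrics` implementation, and it adopts the
+convention that an empty prediction set contributes zero precision.
+
+{% highlight python %}
+# Toy predictions and labels from the example above, as Python sets.
+predictions = [{0, 1}, {0, 2}, set(), {2}, {2, 0}, {0, 1, 2}, {1}]
+labels = [{0, 2}, {0, 1}, {0}, {2}, {2, 0}, {0, 1}, {1, 2}]
+
+n = len(predictions)
+all_labels = set().union(*labels)  # L = {0, 1, 2}
+
+# Document-averaged precision and recall
+precision = sum(len(p & l) / float(len(p)) if p else 0.0
+                for p, l in zip(predictions, labels)) / n
+recall = sum(len(p & l) / float(len(l)) for p, l in zip(predictions, labels)) / n
+
+# Hamming loss: total symmetric-difference size, normalized by N * |L|
+hamming = sum(len(p) + len(l) - 2 * len(p & l)
+              for p, l in zip(predictions, labels)) / float(n * len(all_labels))
+
+# Subset accuracy: exact set matches (docs 3 and 4 here)
+subset_acc = sum(p == l for p, l in zip(predictions, labels)) / float(n)
+
+assert abs(precision - 2.0 / 3.0) < 1e-9
+assert abs(recall - 4.5 / 7.0) < 1e-9
+assert abs(hamming - 7.0 / 21.0) < 1e-9
+assert abs(subset_acc - 2.0 / 7.0) < 1e-9
+{% endhighlight %}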
+
+### Ranking systems
+
+The role of a ranking algorithm (often thought of as a [recommender system](https://en.wikipedia.org/wiki/Recommender_system))
+is to return to the user a set of relevant items or documents based on some training data. The definition of relevance
+may vary and is usually application specific. Ranking system metrics aim to quantify the effectiveness of these
+rankings or recommendations in various contexts. Some metrics compare a set of recommended documents to a ground truth
+set of relevant documents, while other metrics may incorporate numerical ratings explicitly.
+
+**Available metrics**
+
+A ranking system usually deals with a set of $M$ users
+
+$$U = \left\{u_0, u_1, ..., u_{M-1}\right\}$$
+
+Each user ($u_i$) has a set of $N$ ground truth relevant documents
+
+$$D_i = \left\{d_0, d_1, ..., d_{N-1}\right\}$$
+
+and a list of $Q$ recommended documents, in order of decreasing relevance
+
+$$R_i = \left[r_0, r_1, ..., r_{Q-1}\right]$$
+
+The goal of the ranking system is to produce the most relevant set of documents for each user. The relevance of the
+sets and the effectiveness of the algorithms can be measured using the metrics listed below.
+
+It is necessary to define a function which, provided a recommended document and a set of ground truth relevant
+documents, returns a relevance score for the recommended document.
+
+$$rel_D(r) = \begin{cases}1 & \text{if $r \in D$}, \\ 0 & \text{otherwise}.\end{cases}$$
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th><th>Notes</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Precision at k</td>
+      <td>
+        $p(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{k} \sum_{j=0}^{\text{min}(\left|R_i\right|, k) - 1} rel_{D_i}(R_i(j))}$
+      </td>
+      <td>
+        Precision at k is a measure of
+        how many of the first k recommended documents are in the set of true relevant documents averaged across all
+        users. In this metric, the order of the recommendations is not taken into account.
+      </td>
+    </tr>
+    <tr>
+      <td>Mean Average Precision</td>
+      <td>
+        $MAP=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{\left|D_i\right|} \sum_{j=0}^{Q-1} \frac{rel_{D_i}(R_i(j))}{j + 1}}$
+      </td>
+      <td>
+        MAP is a measure of how
+        many of the recommended documents are in the set of true relevant documents, where the
+        order of the recommendations is taken into account (i.e. mistakes on highly ranked documents are penalized
+        more heavily).
+      </td>
+    </tr>
+    <tr>
+      <td>Normalized Discounted Cumulative Gain</td>
+      <td>
+        $NDCG(k)=\frac{1}{M} \sum_{i=0}^{M-1} {\frac{1}{IDCG(D_i, k)}\sum_{j=0}^{n-1}
+          \frac{rel_{D_i}(R_i(j))}{\text{ln}(j+2)}} \\
+        \text{Where} \\
+        \hspace{5 mm} n = \text{min}\left(\text{max}\left(|R_i|,|D_i|\right),k\right) \\
+        \hspace{5 mm} IDCG(D, k) = \sum_{j=0}^{\text{min}(\left|D\right|, k) - 1} \frac{1}{\text{ln}(j+2)}$
+      </td>
+      <td>
+        NDCG at k is a
+        measure of how many of the first k recommended documents are in the set of true relevant documents averaged
+        across all users. In contrast to precision at k, this metric takes into account the order of the recommendations
+        (documents are assumed to be in order of decreasing relevance).
+      </td>
+    </tr>
+  </tbody>
+</table>
+
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train an alternating least squares recommendation
+model on the data, and evaluate the performance of the recommender by several ranking metrics. A brief summary of the
+methodology is provided below.
+
+MovieLens ratings are on a scale of 1-5:
+
+ * 5: Must see
+ * 4: Will enjoy
+ * 3: It's okay
+ * 2: Fairly bad
+ * 1: Awful
+
+So we should not recommend a movie if the predicted rating is less than 3.
+To map ratings to confidence scores, we use:
+
+ * 5 -> 2.5
+ * 4 -> 1.5
+ * 3 -> 0.5
+ * 2 -> -0.5
+ * 1 -> -1.5.
+
+This mapping means unobserved entries are generally between It's okay and Fairly bad. The semantics of 0 in this
+expanded world of non-positive weights are "the same as never having interacted at all."
+
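+Before turning to the Spark examples, the snippet below evaluates the table's definitions directly in
+plain Python for a single made-up user. The recommendation list, the relevant set, and the helper
+functions are illustrative only; they follow the formulas as written above rather than the
+`RankingMetrics` source code.
+
+{% highlight python %}
+import math
+
+# One user's ordered recommendations (R_i) and ground-truth relevant set (D_i);
+# both are invented for this example.
+recommended = [3, 1, 7, 5, 9]
+relevant = {1, 5, 6}
+
+def precision_at_k(rec, rel, k):
+    # Count relevant documents among the first min(|R_i|, k) positions,
+    # always normalizing by k.
+    n = min(len(rec), k)
+    return sum(1 for d in rec[:n] if d in rel) / float(k)
+
+def map_single_user(rec, rel):
+    # The table's MAP summand for one user: rel(R(j)) / (j + 1), scaled by 1/|D|.
+    return sum(1.0 / (j + 1) for j, d in enumerate(rec) if d in rel) / len(rel)
+
+def ndcg_at_k(rec, rel, k):
+    # Discounted gain over the top n positions, normalized by the ideal DCG.
+    n = min(max(len(rec), len(rel)), k)
+    dcg = sum(1.0 / math.log(j + 2) for j in range(min(n, len(rec))) if rec[j] in rel)
+    idcg = sum(1.0 / math.log(j + 2) for j in range(min(len(rel), k)))
+    return dcg / idcg
+
+# Only document 1 appears in the top 3, so p(3) = 1/3; documents 1 and 5 sit at
+# positions 2 and 4, so the MAP term is (1/2 + 1/4) / 3 = 0.25.
+assert abs(precision_at_k(recommended, relevant, 3) - 1.0 / 3.0) < 1e-9
+assert abs(map_single_user(recommended, relevant) - 0.25) < 1e-9
+{% endhighlight %}
+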
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics}
+import org.apache.spark.mllib.recommendation.{ALS, Rating}
+
+// Read in the ratings data
+val ratings = sc.textFile("data/mllib/sample_movielens_data.txt").map { line =>
+  val fields = line.split("::")
+  Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble - 2.5)
+}.cache()
+
+// Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+val binarizedRatings = ratings.map(r => Rating(r.user, r.product, if (r.rating > 0) 1.0 else 0.0)).cache()
+
+// Summarize ratings
+val numRatings = ratings.count()
+val numUsers = ratings.map(_.user).distinct().count()
+val numMovies = ratings.map(_.product).distinct().count()
+println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
+// Build the model
+val numIterations = 10
+val rank = 10
+val lambda = 0.01
+val model = ALS.train(ratings, rank, numIterations, lambda)
+
+// Define a function to scale ratings from 0 to 1
+def scaledRating(r: Rating): Rating = {
+  val scaledRating = math.max(math.min(r.rating, 1.0), 0.0)
+  Rating(r.user, r.product, scaledRating)
+}
+
+// Get sorted top ten predictions for each user and then scale from [0, 1]
+val userRecommended = model.recommendProductsForUsers(10).map { case (user, recs) =>
+  (user, recs.map(scaledRating))
+}
+
+// Assume that any movie a user rated 3 or higher (which maps to a 1) is a relevant document
+// Compare with top ten most relevant documents
+val userMovies = binarizedRatings.groupBy(_.user)
+val relevantDocuments = userMovies.join(userRecommended).map { case (user, (actual, predictions)) =>
+  (predictions.map(_.product), actual.filter(_.rating > 0.0).map(_.product).toArray)
+}
+
+// Instantiate metrics object
+val metrics = new RankingMetrics(relevantDocuments)
+
+// Precision at K
+Array(1, 3, 5).foreach { k =>
+  println(s"Precision at $k = ${metrics.precisionAt(k)}")
+}
+
+// Mean average precision
+println(s"Mean average precision = ${metrics.meanAveragePrecision}")
+
+// Normalized discounted cumulative gain
+Array(1, 3, 5).foreach { k =>
+  println(s"NDCG at $k = ${metrics.ndcgAt(k)}")
+}
+
+// Get predictions for each data point
+val allPredictions = model.predict(ratings.map(r => (r.user, r.product))).map(r => ((r.user, r.product), r.rating))
+val allRatings = ratings.map(r => ((r.user, r.product), r.rating))
+val predictionsAndLabels = allPredictions.join(allRatings).map { case ((user, product), (predicted, actual)) =>
+  (predicted, actual)
+}
+
+// Get the RMSE using regression metrics
+val regressionMetrics = new RegressionMetrics(predictionsAndLabels)
+println(s"RMSE = ${regressionMetrics.rootMeanSquaredError}")
+
+// R-squared
+println(s"R-squared = ${regressionMetrics.r2}")
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.Function;
+import java.util.*;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.mllib.evaluation.RankingMetrics;
+import org.apache.spark.mllib.recommendation.ALS;
+import org.apache.spark.mllib.recommendation.Rating;
+
+public class Ranking {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Ranking Metrics");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Read in the ratings data
+    String path = "data/mllib/sample_movielens_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<Rating> ratings = data.map(
+      new Function<String, Rating>() {
+        public Rating call(String line) {
+          String[] parts = line.split("::");
+          return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double.parseDouble(parts[2]) - 2.5);
+        }
+      }
+    );
+    ratings.cache();
+
+    // Train an ALS model
+    final MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 10, 10, 0.01);
+
+    // Get top 10 recommendations for every user and scale ratings from 0 to 1
+    JavaRDD<Tuple2<Object, Rating[]>> userRecs = model.recommendProductsForUsers(10).toJavaRDD();
+    JavaRDD<Tuple2<Object, Rating[]>> userRecsScaled = userRecs.map(
+      new Function<Tuple2<Object, Rating[]>, Tuple2<Object, Rating[]>>() {
+        public Tuple2<Object, Rating[]> call(Tuple2<Object, Rating[]> t) {
+          Rating[] scaledRatings = new Rating[t._2().length];
+          for (int i = 0; i < scaledRatings.length; i++) {
+            double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0);
+            scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating);
+          }
+          return new Tuple2<Object, Rating[]>(t._1(), scaledRatings);
+        }
+      }
+    );
+    JavaPairRDD<Object, Rating[]> userRecommended = JavaPairRDD.fromJavaRDD(userRecsScaled);
+
+    // Map ratings to 1 or 0, 1 indicating a movie that should be recommended
+    JavaRDD<Rating> binarizedRatings = ratings.map(
+      new Function<Rating, Rating>() {
+        public Rating call(Rating r) {
+          double binaryRating;
+          if (r.rating() > 0.0) {
+            binaryRating = 1.0;
+          }
+          else {
+            binaryRating = 0.0;
+          }
+          return new Rating(r.user(), r.product(), binaryRating);
+        }
+      }
+    );
+
+    // Group ratings by common user
+    JavaPairRDD<Object, Iterable<Rating>> userMovies = binarizedRatings.groupBy(
+      new Function<Rating, Object>() {
+        public Object call(Rating r) {
+          return r.user();
+        }
+      }
+    );
+
+    // Get true relevant documents from all user ratings
+    JavaPairRDD<Object, List<Integer>> userMoviesList = userMovies.mapValues(
+      new Function<Iterable<Rating>, List<Integer>>() {
+        public List<Integer> call(Iterable<Rating> docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            if (r.rating() > 0.0) {
+              products.add(r.product());
+            }
+          }
+          return products;
+        }
+      }
+    );
+
+    // Extract the product id from each recommendation
+    JavaPairRDD<Object, List<Integer>> userRecommendedList = userRecommended.mapValues(
+      new Function<Rating[], List<Integer>>() {
+        public List<Integer> call(Rating[] docs) {
+          List<Integer> products = new ArrayList<Integer>();
+          for (Rating r : docs) {
+            products.add(r.product());
+          }
+          return products;
+        }
+      }
+    );
+    JavaRDD<Tuple2<List<Integer>, List<Integer>>> relevantDocs = userMoviesList.join(userRecommendedList).values();
+
+    // Instantiate the metrics object
+    RankingMetrics metrics = RankingMetrics.of(relevantDocs);
+
+    // Precision and NDCG at k
+    Integer[] kVector = {1, 3, 5};
+    for (Integer k : kVector) {
+      System.out.format("Precision at %d = %f\n", k, metrics.precisionAt(k));
+      System.out.format("NDCG at %d = %f\n", k, metrics.ndcgAt(k));
+    }
+
+    // Mean average precision
+    System.out.format("Mean average precision = %f\n", metrics.meanAveragePrecision());
+
+    // Evaluate the model using numerical ratings and regression metrics
+    JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(
+      new Function<Rating, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(Rating r) {
+          return new Tuple2<Object, Object>(r.user(), r.product());
+        }
+      }
+    );
+    JavaPairRDD<Tuple2<Integer, Integer>, Object> predictions = JavaPairRDD.fromJavaRDD(
+      model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r){
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      ));
+    JavaRDD<Tuple2<Object, Object>> ratesAndPreds =
+      JavaPairRDD.fromJavaRDD(ratings.map(
+        new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Object>>() {
+          public Tuple2<Tuple2<Integer, Integer>, Object> call(Rating r){
+            return new Tuple2<Tuple2<Integer, Integer>, Object>(
+              new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
+          }
+        }
+      )).join(predictions).values();
+
+    // Create regression metrics object
+    RegressionMetrics regressionMetrics = new RegressionMetrics(ratesAndPreds.rdd());
+
+    // Root mean squared error
+    System.out.format("RMSE = %f\n", regressionMetrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R-squared = %f\n", regressionMetrics.r2());
+  }
+}
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics
+
+# Read in the ratings data
+lines = sc.textFile("data/mllib/sample_movielens_data.txt")
+
+def parseLine(line):
+    fields = line.split("::")
+    return Rating(int(fields[0]), int(fields[1]), float(fields[2]) - 2.5)
+ratings = lines.map(lambda r: parseLine(r))
+
+# Train a model to predict user-product ratings
+model = ALS.train(ratings, 10, 10, 0.01)
+
+# Get predicted ratings on all existing user-product pairs
+testData = ratings.map(lambda p: (p.user, p.product))
+predictions = model.predictAll(testData).map(lambda r: ((r.user, r.product), r.rating))
+
+ratingsTuple = ratings.map(lambda r: ((r.user, r.product), r.rating))
+scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1])
+
+# Instantiate regression metrics to compare predicted and actual ratings
+metrics = RegressionMetrics(scoreAndLabels)
+
+# Root mean squared error
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+{% endhighlight %}
+
+</div>
+
+</div>
+
+## Regression model evaluation
+
+[Regression analysis](https://en.wikipedia.org/wiki/Regression_analysis) is used when predicting a continuous output
+variable from a number of independent variables.
+
+**Available metrics**
+
+<table class="table">
+  <thead>
+    <tr><th>Metric</th><th>Definition</th></tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>Mean Squared Error (MSE)</td>
+      <td>$MSE = \frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}$</td>
+    </tr>
+    <tr>
+      <td>Root Mean Squared Error (RMSE)</td>
+      <td>$RMSE = \sqrt{\frac{\sum_{i=0}^{N-1} (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{N}}$</td>
+    </tr>
+    <tr>
+      <td>Mean Absolute Error (MAE)</td>
+      <td>$MAE=\frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$</td>
+    </tr>
+    <tr>
+      <td>Coefficient of Determination $(R^2)$</td>
+      <td>$R^2=1 - \frac{N \cdot MSE}{\text{VAR}(\mathbf{y}) \cdot (N-1)}=1-\frac{\sum_{i=0}^{N-1}
+        (\mathbf{y}_i - \hat{\mathbf{y}}_i)^2}{\sum_{i=0}^{N-1}(\mathbf{y}_i-\bar{\mathbf{y}})^2}$</td>
+    </tr>
+    <tr>
+      <td>Explained Variance</td>
+      <td>$1 - \frac{\text{VAR}(\mathbf{y} - \mathbf{\hat{y}})}{\text{VAR}(\mathbf{y})}$</td>
+    </tr>
+  </tbody>
+</table>
+
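+As a quick sanity check on the formulas above, here is a small plain-Python computation over a
+handful of made-up (prediction, label) pairs; it is independent of Spark and only illustrates the
+definitions in the table.
+
+{% highlight python %}
+import math
+
+# Made-up (prediction, label) pairs, purely for illustrating the formulas.
+pairs = [(2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)]
+
+n = float(len(pairs))
+labels = [y for (_, y) in pairs]
+mean_y = sum(labels) / n
+
+mse = sum((y - p) ** 2 for (p, y) in pairs) / n
+rmse = math.sqrt(mse)
+mae = sum(abs(y - p) for (p, y) in pairs) / n
+
+# R^2 compares the residual sum of squares to the total variation around the mean
+ss_res = sum((y - p) ** 2 for (p, y) in pairs)
+ss_tot = sum((y - mean_y) ** 2 for y in labels)
+r2 = 1.0 - ss_res / ss_tot
+
+assert abs(mse - 0.375) < 1e-9
+assert abs(rmse - math.sqrt(0.375)) < 1e-9
+assert abs(mae - 0.5) < 1e-9
+{% endhighlight %}
+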
+**Examples**
+
+The following code snippets illustrate how to load a sample dataset, train a linear regression algorithm on the data,
+and evaluate the performance of the algorithm using several regression metrics.
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.LinearRegressionModel
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.evaluation.RegressionMetrics
+import org.apache.spark.mllib.util.MLUtils
+
+// Load the data
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_linear_regression_data.txt").cache()
+
+// Build the model
+val numIterations = 100
+val model = LinearRegressionWithSGD.train(data, numIterations)
+
+// Get predictions
+val valuesAndPreds = data.map { point =>
+  val prediction = model.predict(point.features)
+  (prediction, point.label)
+}
+
+// Instantiate metrics object
+val metrics = new RegressionMetrics(valuesAndPreds)
+
+// Squared error
+println(s"MSE = ${metrics.meanSquaredError}")
+println(s"RMSE = ${metrics.rootMeanSquaredError}")
+
+// R-squared
+println(s"R-squared = ${metrics.r2}")
+
+// Mean absolute error
+println(s"MAE = ${metrics.meanAbsoluteError}")
+
+// Explained variance
+println(s"Explained variance = ${metrics.explainedVariance}")
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java" markdown="1">
+
+{% highlight java %}
+import scala.Tuple2;
+
+import org.apache.spark.api.java.*;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.regression.LinearRegressionModel;
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
+import org.apache.spark.mllib.evaluation.RegressionMetrics;
+import org.apache.spark.SparkConf;
+
+public class LinearRegression {
+  public static void main(String[] args) {
+    SparkConf conf = new SparkConf().setAppName("Linear Regression Example");
+    JavaSparkContext sc = new JavaSparkContext(conf);
+
+    // Load and parse the data
+    String path = "data/mllib/sample_linear_regression_data.txt";
+    JavaRDD<String> data = sc.textFile(path);
+    JavaRDD<LabeledPoint> parsedData = data.map(
+      new Function<String, LabeledPoint>() {
+        public LabeledPoint call(String line) {
+          String[] parts = line.split(" ");
+          double[] v = new double[parts.length - 1];
+          for (int i = 1; i < parts.length; i++)
+            v[i - 1] = Double.parseDouble(parts[i].split(":")[1]);
+          return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+        }
+      }
+    );
+    parsedData.cache();
+
+    // Building the model
+    int numIterations = 100;
+    final LinearRegressionModel model =
+      LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations);
+
+    // Evaluate model on training examples and compute training error
+    JavaRDD<Tuple2<Object, Object>> valuesAndPreds = parsedData.map(
+      new Function<LabeledPoint, Tuple2<Object, Object>>() {
+        public Tuple2<Object, Object> call(LabeledPoint point) {
+          double prediction = model.predict(point.features());
+          return new Tuple2<Object, Object>(prediction, point.label());
+        }
+      }
+    );
+
+    // Instantiate metrics object
+    RegressionMetrics metrics = new RegressionMetrics(valuesAndPreds.rdd());
+
+    // Squared error
+    System.out.format("MSE = %f\n", metrics.meanSquaredError());
+    System.out.format("RMSE = %f\n", metrics.rootMeanSquaredError());
+
+    // R-squared
+    System.out.format("R Squared = %f\n", metrics.r2());
+
+    // Mean absolute error
+    System.out.format("MAE = %f\n", metrics.meanAbsoluteError());
+
+    // Explained variance
+    System.out.format("Explained Variance = %f\n", metrics.explainedVariance());
+
+    // Save and load model
+    model.save(sc.sc(), "myModelPath");
+    LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
+  }
+}
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python" markdown="1">
+
+{% highlight python %}
+from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD
+from pyspark.mllib.evaluation import RegressionMetrics
+from pyspark.mllib.linalg import DenseVector
+
+# Load and parse the data
+def parsePoint(line):
+    values = line.split()
+    return LabeledPoint(float(values[0]), DenseVector([float(x.split(':')[1]) for x in values[1:]]))
+
+data = sc.textFile("data/mllib/sample_linear_regression_data.txt")
+parsedData = data.map(parsePoint)
+
+# Build the model
+model = LinearRegressionWithSGD.train(parsedData)
+
+# Get predictions
+valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.label))
+
+# Instantiate metrics object
+metrics = RegressionMetrics(valuesAndPreds)
+
+# Squared Error
+print "MSE = %s" % metrics.meanSquaredError
+print "RMSE = %s" % metrics.rootMeanSquaredError
+
+# R-squared
+print "R-squared = %s" % metrics.r2
+
+# Mean absolute error
+print "MAE = %s" % metrics.meanAbsoluteError
+
+# Explained variance
+print "Explained variance = %s" % metrics.explainedVariance
+{% endhighlight %}
+
+</div>
+
+</div>
\ No newline at end of file diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index d2d1cc93fe006..eea864eacf7c4 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -48,6 +48,7 @@ This lists functionality included in `spark.mllib`, the main MLlib API. * [Feature extraction and transformation](mllib-feature-extraction.html) * [Frequent pattern mining](mllib-frequent-pattern-mining.html) * FP-growth +* [Evaluation Metrics](mllib-evaluation-metrics.html) * [Optimization (developer)](mllib-optimization.html) * stochastic gradient descent * limited-memory BFGS (L-BFGS) From a200e64561c8803731578267df16906f6773cbea Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 29 Jul 2015 19:02:15 -0700 Subject: [PATCH 144/219] [SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load jkbradley MechCoder Resolves blocking issue for SPARK-6793. Please review after #7705 is merged. Author: Feynman Liang Closes #7757 from feynmanliang/SPARK-9940-localSaveLoad and squashes the following commits: d0d8cf4 [Feynman Liang] Fix thisClassName 0f30109 [Feynman Liang] Fix tests after changing LDAModel public API dc61981 [Feynman Liang] Add hyperparams to LocalLDAModel save/load --- .../spark/mllib/clustering/LDAModel.scala | 40 +++++++++++++------ .../spark/mllib/clustering/LDASuite.scala | 6 ++- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 059b52ef20a98..ece28848aa02c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -215,7 +215,8 @@ class LocalLDAModel private[clustering] ( override protected def formatVersion = "1.0" override def save(sc: SparkContext, path: String): Unit = { - LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix) + LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, + gammaShape) } // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -312,16 +313,23 @@ object LocalLDAModel extends Loader[LocalLDAModel] { // as a Row in data. 
case class Data(topic: Vector, index: Int) - // TODO: explicitly save docConcentration, topicConcentration, and gammaShape for use in - // model.predict() - def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = { + def save( + sc: SparkContext, + path: String, + topicsMatrix: Matrix, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): Unit = { val sqlContext = SQLContext.getOrCreate(sc) import sqlContext.implicits._ val k = topicsMatrix.numCols val metadata = compact(render (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ - ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows) ~ + ("docConcentration" -> docConcentration.toArray.toSeq) ~ + ("topicConcentration" -> topicConcentration) ~ + ("gammaShape" -> gammaShape))) sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix @@ -331,7 +339,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] { sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) } - def load(sc: SparkContext, path: String): LocalLDAModel = { + def load( + sc: SparkContext, + path: String, + docConcentration: Vector, + topicConcentration: Double, + gammaShape: Double): LocalLDAModel = { val dataPath = Loader.dataPath(path) val sqlContext = SQLContext.getOrCreate(sc) val dataFrame = sqlContext.read.parquet(dataPath) @@ -348,8 +361,7 @@ object LocalLDAModel extends Loader[LocalLDAModel] { val topicsMat = Matrices.fromBreeze(brzTopics) // TODO: initialize with docConcentration, topicConcentration, and gammaShape after SPARK-9940 - new LocalLDAModel(topicsMat, - Vectors.dense(Array.fill(topicsMat.numRows)(1.0 / topicsMat.numRows)), 1D, 100D) + new LocalLDAModel(topicsMat, docConcentration, topicConcentration, gammaShape) } } @@ -358,11 +370,15 @@ object LocalLDAModel extends Loader[LocalLDAModel] { implicit val formats = DefaultFormats val expectedK = (metadata \ "k").extract[Int] val expectedVocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = + Vectors.dense((metadata \ "docConcentration").extract[Seq[Double]].toArray) + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val gammaShape = (metadata \ "gammaShape").extract[Double] val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => - SaveLoadV1_0.load(sc, path) + SaveLoadV1_0.load(sc, path, docConcentration, topicConcentration, gammaShape) case _ => throw new Exception( s"LocalLDAModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $loadedVersion). Supported:\n" + @@ -565,7 +581,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val thisFormatVersion = "1.0" - val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel" + val thisClassName = "org.apache.spark.mllib.clustering.DistributedLDAModel" // Store globalTopicTotals as a Vector. 
case class Data(globalTopicTotals: Vector) @@ -591,7 +607,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { import sqlContext.implicits._ val metadata = compact(render - (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~ + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration.toArray.toSeq) ~ ("topicConcentration" -> topicConcentration) ~ @@ -660,7 +676,7 @@ object DistributedLDAModel extends Loader[DistributedLDAModel] { val topicConcentration = (metadata \ "topicConcentration").extract[Double] val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] val gammaShape = (metadata \ "gammaShape").extract[Double] - val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + val classNameV1_0 = SaveLoadV1_0.thisClassName val model = (loadedClassName, loadedVersion) match { case (className, "1.0") if className == classNameV1_0 => { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index aa36336ebbee6..b91c7cefed22e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("model save/load") { // Test for LocalLDAModel. val localModel = new LocalLDAModel(tinyTopics, - Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) + Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D) val tempDir1 = Utils.createTempDir() val path1 = tempDir1.toURI.toString @@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) assert(samelocalModel.k === localModel.k) assert(samelocalModel.vocabSize === localModel.vocabSize) + assert(samelocalModel.docConcentration === localModel.docConcentration) + assert(samelocalModel.topicConcentration === localModel.topicConcentration) + assert(samelocalModel.gammaShape === localModel.gammaShape) val sameDistributedModel = DistributedLDAModel.load(sc, path2) assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) @@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.gammaShape === sameDistributedModel.gammaShape) assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) val graph = distributedModel.graph From 9514d874f0cf61f1eb4ec4f5f66e053119f769c9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 20:46:03 -0700 Subject: [PATCH 145/219] [SPARK-9458] Avoid object allocation in prefix generation. In our existing sort prefix generation code, we use expression's eval method to generate the prefix, which results in object allocation for every prefix. We can use the specialized getters available on InternalRow directly to avoid the object allocation. I also removed the FLOAT prefix, opting for converting float directly to double. 
Author: Reynold Xin Closes #7763 from rxin/sort-prefix and squashes the following commits: 5dc2f06 [Reynold Xin] [SPARK-9458] Avoid object allocation in prefix generation. --- .../unsafe/sort/PrefixComparators.java | 16 ------ .../unsafe/sort/PrefixComparatorsSuite.scala | 12 ----- .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/execution/SortPrefixUtils.scala | 51 +++++++++---------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 5 +- .../execution/RowFormatConvertersSuite.scala | 2 +- .../execution/UnsafeExternalSortSuite.scala | 10 ++-- 8 files changed, 35 insertions(+), 67 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index bf1bc5dffba78..5624e067da2cc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -31,7 +31,6 @@ private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); public static final class StringPrefixComparator extends PrefixComparator { @@ -78,21 +77,6 @@ public int compare(long a, long b) { public final long NULL_PREFIX = Long.MIN_VALUE; } - public static final class FloatPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); - } - - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; - } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); - } - public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dc03e374b51db..28fe9259453a6 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -48,18 +48,6 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = 
java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557140..8342833246f7d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -121,7 +121,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..050d27f1460fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.expressions.{BoundReference, SortOrder} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -39,57 +39,54 @@ object SortPrefixUtils { sortOrder.dataType match { case StringType => PrefixComparators.STRING case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case FloatType | DoubleType => PrefixComparators.DOUBLE case _ => NoOpPrefixComparator } } def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { + val bound = sortOrder.child.asInstanceOf[BoundReference] + val pos = bound.ordinal sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } + case StringType => + (row: InternalRow) => { + PrefixComparators.STRING.computePrefix(row.getUTF8String(pos)) + } case BooleanType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX + else if (row.getBoolean(pos)) 1 else 0 } case ByteType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getByte(pos) } case ShortType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getShort(pos) } case IntegerType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else 
row.getInt(pos) } case LongType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] + if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getLong(pos) } case FloatType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) + if (row.isNullAt(pos)) { + PrefixComparators.DOUBLE.NULL_PREFIX + } else { + PrefixComparators.DOUBLE.computePrefix(row.getFloat(pos).toDouble) + } } case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) + if (row.isNullAt(pos)) { + PrefixComparators.DOUBLE.NULL_PREFIX + } else { + PrefixComparators.DOUBLE.computePrefix(row.getDouble(pos)) + } } case _ => (row: InternalRow) => 0L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f3ef066528ff8..4ab2c41f1b339 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -340,8 +340,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index f82208868c3e3..d0ad310062853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -97,7 +97,7 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class UnsafeExternalSort( +case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, @@ -110,7 +110,6 @@ case class UnsafeExternalSort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) @@ -149,7 +148,7 @@ case class UnsafeExternalSort( } @DeveloperApi -object UnsafeExternalSort { +object TungstenSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..c458f95ca1ab3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -31,7 +31,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 7a4baa9e4a49d..9cabc4b90bf8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -42,7 +42,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -53,7 +53,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { try { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -68,7 +68,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -88,11 +88,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 07fd7d36471dfb823c1ce3e3a18464043affde18 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 21:18:43 -0700 Subject: [PATCH 146/219] [SPARK-9460] Avoid byte array allocation in 
StringPrefixComparator. As of today, StringPrefixComparator converts the long values back to byte arrays in order to compare them. This patch optimizes this to compare the longs directly, rather than turning the longs into byte arrays and comparing them byte by byte (unsigned). This only works on little-endian architecture right now. Author: Reynold Xin Closes #7765 from rxin/SPARK-9460 and squashes the following commits: e4908cc [Reynold Xin] Stricter randomized tests. 4c8d094 [Reynold Xin] [SPARK-9460] Avoid byte array allocation in StringPrefixComparator. --- .../unsafe/sort/PrefixComparators.java | 29 ++----------------- .../unsafe/sort/PrefixComparatorsSuite.scala | 19 ++++++++---- .../apache/spark/unsafe/types/UTF8String.java | 9 ++++++ .../spark/unsafe/types/UTF8StringSuite.java | 11 +++++++ 4 files changed, 36 insertions(+), 32 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 5624e067da2cc..a9ee6042fec74 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -17,9 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; -import com.google.common.base.Charsets; -import com.google.common.primitives.Longs; -import com.google.common.primitives.UnsignedBytes; +import com.google.common.primitives.UnsignedLongs; import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; @@ -36,32 +34,11 @@ private PrefixComparators() {} public static final class StringPrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - // TODO: can done more efficiently - byte[] a = Longs.toByteArray(aPrefix); - byte[] b = Longs.toByteArray(bPrefix); - for (int i = 0; i < 8; i++) { - int c = UnsignedBytes.compare(a[i], b[i]); - if (c != 0) return c; - } - return 0; - } - - public long computePrefix(byte[] bytes) { - if (bytes == null) { - return 0L; - } else { - byte[] padded = new byte[8]; - System.arraycopy(bytes, 0, padded, 0, Math.min(bytes.length, 8)); - return Longs.fromByteArray(padded); - } - } - - public long computePrefix(String value) { - return value == null ? 0L : computePrefix(value.getBytes(Charsets.UTF_8)); + return UnsignedLongs.compare(aPrefix, bPrefix); } public long computePrefix(UTF8String value) { - return value == null ? 0L : computePrefix(value.getBytes()); + return value == null ? 
0L : value.getPrefix(); } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 28fe9259453a6..26b7a9e816d1e 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -17,22 +17,29 @@ package org.apache.spark.util.collection.unsafe.sort +import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks - import org.apache.spark.SparkFunSuite +import org.apache.spark.unsafe.types.UTF8String class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("String prefix comparator") { def testPrefixComparison(s1: String, s2: String): Unit = { - val s1Prefix = PrefixComparators.STRING.computePrefix(s1) - val s2Prefix = PrefixComparators.STRING.computePrefix(s2) + val utf8string1 = UTF8String.fromString(s1) + val utf8string2 = UTF8String.fromString(s2) + val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) + + val cmp = UnsignedBytes.lexicographicalComparator().compare( + utf8string1.getBytes.take(8), utf8string2.getBytes.take(8)) + assert( - (prefixComparisonResult == 0) || - (prefixComparisonResult < 0 && s1 < s2) || - (prefixComparisonResult > 0 && s1 > s2)) + (prefixComparisonResult == 0 && cmp == 0) || + (prefixComparisonResult < 0 && s1.compareTo(s2) < 0) || + (prefixComparisonResult > 0 && s1.compareTo(s2) > 0)) } // scalastyle:off diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3e1cc67dbf337..57522003ba2ba 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -137,6 +137,15 @@ public int numChars() { return len; } + /** + * Returns a 64-bit integer that can be used as the prefix used in sorting. + */ + public long getPrefix() { + long p = PlatformDependent.UNSAFE.getLong(base, offset); + p = java.lang.Long.reverseBytes(p); + return p; + } + /** * Returns the underline bytes, will be a copy of it if it's part of another array. 
*/ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index e2a5628ff4d93..42e09e435a412 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -63,8 +63,19 @@ public void emptyStringTest() { assertEquals(0, EMPTY_UTF8.numBytes()); } + @Test + public void prefix() { + assertTrue(fromString("a").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue(fromString("ab").getPrefix() - fromString("b").getPrefix() < 0); + assertTrue( + fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); + assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); + assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + } + @Test public void compareTo() { + assertTrue(fromString("").compareTo(fromString("a")) < 0); assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); From 27850af5255352cebd933ed3cc3d82c9ff6e9b62 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 21:24:47 -0700 Subject: [PATCH 147/219] [SPARK-9462][SQL] Initialize nondeterministic expressions in code gen fallback mode. Author: Reynold Xin Closes #7767 from rxin/SPARK-9462 and squashes the following commits: ef3e2d9 [Reynold Xin] Removed println 713ac3a [Reynold Xin] More unit tests. bb5c334 [Reynold Xin] [SPARK-9462][SQL] Initialize nondeterministic expressions in code gen fallback mode. --- .../expressions/codegen/CodegenFallback.scala | 7 ++- .../CodegenExpressionCachingSuite.scala | 46 +++++++++++++++++-- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6b187f05604fd..3492d2c6189ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} /** * A trait that can be used to provide a fallback mode for expression code generation. 
@@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression trait CodegenFallback extends Expression { protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } + ctx.references += this val objectTerm = ctx.freshName("obj") s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 866bf904e4a4c..2d3f98dbbd3d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -27,7 +27,32 @@ import org.apache.spark.sql.types.{BooleanType, DataType} */ class CodegenExpressionCachingSuite extends SparkFunSuite { - test("GenerateUnsafeProjection") { + test("GenerateUnsafeProjection should initialize expressions") { + // Use an And to wrap two of them together, in case we only initialize the top-level expression. + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = UnsafeProjection.create(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateProjection.generate(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateMutableProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateMutableProjection.generate(Seq(expr))() + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GeneratePredicate should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GeneratePredicate.generate(expr) + assert(instance.apply(null) === false) + } + + test("GenerateUnsafeProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = UnsafeProjection.create(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) @@ -39,7 +64,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection") { + test("GenerateProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateProjection.generate(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) @@ -51,7 +76,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateMutableProjection") { + test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() assert(instance1.apply(null).getBoolean(0) === false) @@ -63,7 +88,7 @@ class
CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GeneratePredicate") { + test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) assert(instance1.apply(null) === false) @@ -77,6 +102,17 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { } +/** + * An expression that's non-deterministic and doesn't support codegen. + */ +case class NondeterministicExpression() + extends LeafExpression with Nondeterministic with CodegenFallback { + override protected def initInternal(): Unit = { } + override protected def evalInternal(input: InternalRow): Any = false + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} + /** * An expression with mutable state so we can change it freely in our test suite. From f5dd11339fc9a6d11350f63beeca7c14aec169b1 Mon Sep 17 00:00:00 2001 From: Alex Angelini Date: Wed, 29 Jul 2015 22:25:38 -0700 Subject: [PATCH 148/219] Fix reference to self.names in StructType `names` is not defined in this context; I think you meant `self.names`. davies Author: Alex Angelini Closes #7766 from angelini/fix_struct_type_names and squashes the following commits: 01543a1 [Alex Angelini] Fix reference to self.names in StructType --- python/pyspark/sql/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b97d50c945f24..8859308d66027 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -531,7 +531,7 @@ def toInternal(self, obj): if self._needSerializeFields: if isinstance(obj, dict): - return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) else: From e044705b4402f86d0557ecd146f3565388c7eeb4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 29 Jul 2015 22:30:49 -0700 Subject: [PATCH 149/219] [SPARK-9116] [SQL] [PYSPARK] support Python only UDT in __main__ This also makes it possible to create a Python UDT without having a Scala one, which is important for Python users.
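Concretely, the Python class itself now travels inside the schema: the UDT's JSON representation carries the cloudpickled class as a base64 string, and the Scala side (see the DataType.scala and UserDefinedType.scala hunks later in this patch) parses it into a PythonUserDefinedType without needing any Scala class. A hedged sketch of that wire format as seen from Scala; the base64 payload here is a placeholder, not real pickle data:

import org.apache.spark.sql.types.DataType

// A Python-only UDT arrives with no "class" entry (there is no Scala UDT),
// but with the cloudpickled Python class under "serializedClass".
val json =
  """{"type":"udt",
    |"pyClass":"__main__.PythonOnlyUDT",
    |"serializedClass":"<base64-of-cloudpickled-class>",
    |"sqlType":{"type":"array","elementType":"double","containsNull":false}}""".stripMargin

// DataType.fromJson matches on (pyClass, serializedClass, sqlType, type) and
// builds a PythonUserDefinedType wrapping the underlying array type; the
// payload is only stored and echoed back, never unpickled on the JVM side.
val dt = DataType.fromJson(json)
assert(dt.json.contains("serializedClass"))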
cc mengxr JoshRosen Author: Davies Liu Closes #7453 from davies/class_in_main and squashes the following commits: 4dfd5e1 [Davies Liu] add tests for Python and Scala UDT 793d9b2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main dc65f19 [Davies Liu] address comment a9a3c40 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main a86e1fc [Davies Liu] fix serialization ad528ba [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main 63f52ef [Davies Liu] fix pylint check 655b8a9 [Davies Liu] Merge branch 'master' of github.com:apache/spark into class_in_main 316a394 [Davies Liu] support Python UDT with UTF 0bcb3ef [Davies Liu] fix bug in mllib de986d6 [Davies Liu] fix test 83d65ac [Davies Liu] fix bug in StructType 55bb86e [Davies Liu] support Python UDT in __main__ (without Scala one) --- pylintrc | 2 +- python/pyspark/cloudpickle.py | 38 +++++- python/pyspark/shuffle.py | 2 +- python/pyspark/sql/context.py | 108 ++++++++++------- python/pyspark/sql/tests.py | 112 ++++++++++++++++-- python/pyspark/sql/types.py | 78 ++++++------ .../org/apache/spark/sql/types/DataType.scala | 9 ++ .../spark/sql/types/UserDefinedType.scala | 29 +++++ .../spark/sql/execution/pythonUDFs.scala | 1 - 9 files changed, 286 insertions(+), 93 deletions(-) diff --git a/pylintrc b/pylintrc index 061775960393b..6a675770da69a 100644 --- a/pylintrc +++ b/pylintrc @@ -84,7 +84,7 @@ enable= # If you would like to improve the code quality of pyspark, remove any of these disabled errors # run ./dev/lint-python and see if the errors raised by pylint can be fixed. -disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable 
+disable=invalid-name,missing-docstring,protected-access,unused-argument,no-member,unused-wildcard-import,redefined-builtin,too-many-arguments,unused-variable,too-few-public-methods,bad-continuation,duplicate-code,redefined-outer-name,too-many-ancestors,import-error,superfluous-parens,unused-import,line-too-long,no-name-in-module,unnecessary-lambda,import-self,no-self-use,unidiomatic-typecheck,fixme,too-many-locals,cyclic-import,too-many-branches,bare-except,wildcard-import,dangerous-default-value,broad-except,too-many-public-methods,deprecated-lambda,anomalous-backslash-in-string,too-many-lines,reimported,too-many-statements,bad-whitespace,unpacking-non-sequence,too-many-instance-attributes,abstract-method,old-style-class,global-statement,attribute-defined-outside-init,arguments-differ,undefined-all-variable,no-init,useless-else-on-loop,super-init-not-called,notimplemented-raised,too-many-return-statements,pointless-string-statement,global-variable-undefined,bad-classmethod-argument,too-many-format-args,parse-error,no-self-argument,pointless-statement,undefined-variable,undefined-loop-variable [REPORTS] diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 9ef93071d2e77..3b647985801b7 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -350,7 +350,26 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ - self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj) + self.save(_load_class) + self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) + d.pop('__doc__', None) + # handle property and staticmethod + dd = {} + for k, v in d.items(): + if isinstance(v, property): + k = ('property', k) + v = (v.fget, v.fset, v.fdel, v.__doc__) + elif isinstance(v, staticmethod) and hasattr(v, '__func__'): + k = ('staticmethod', k) + v = v.__func__ + elif isinstance(v, classmethod) and hasattr(v, '__func__'): + k = ('classmethod', k) + v = v.__func__ + dd[k] = v + self.save(dd) + self.write(pickle.TUPLE2) + self.write(pickle.REDUCE) + else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -708,6 +727,23 @@ def _make_skel_func(code, closures, base_globals = None): None, None, closure) +def _load_class(cls, d): + """ + Loads additional properties into class `cls`. 
+ """ + for k, v in d.items(): + if isinstance(k, tuple): + typ, k = k + if typ == 'property': + v = property(*v) + elif typ == 'staticmethod': + v = staticmethod(v) + elif typ == 'classmethod': + v = classmethod(v) + setattr(cls, k, v) + return cls + + """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 8fb71bac64a5e..b8118bdb7ca76 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -606,7 +606,7 @@ def _open_file(self): if not os.path.exists(d): os.makedirs(d) p = os.path.join(d, str(id(self))) - self._file = open(p, "wb+", 65536) + self._file = open(p, "w+b", 65536) self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) os.unlink(p) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index abb6522dde7b0..917de24f3536b 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -277,6 +277,66 @@ def applySchema(self, rdd, schema): return self.createDataFrame(rdd, schema) + def _createFromRDD(self, rdd, schema, samplingRatio): + """ + Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. + """ + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchema(rdd, samplingRatio) + converter = _create_converter(struct) + rdd = rdd.map(converter) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + # take the first few rows to verify schema + rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + rdd = rdd.map(schema.toInternal) + return rdd, schema + + def _createFromLocal(self, data, schema): + """ + Create an RDD for DataFrame from an list or pandas.DataFrame, returns + the RDD and schema. + """ + if has_pandas and isinstance(data, pandas.DataFrame): + if schema is None: + schema = [str(x) for x in data.columns] + data = [r.tolist() for r in data.to_records(index=False)] + + # make sure data could consumed multiple times + if not isinstance(data, list): + data = list(data) + + if schema is None or isinstance(schema, (list, tuple)): + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + struct.names[i] = name + schema = struct + + elif isinstance(schema, StructType): + for row in data: + _verify_type(row, schema) + + else: + raise TypeError("schema should be StructType or list or None, but got: %s" % schema) + + # convert python objects to sql data + data = [schema.toInternal(row) for row in data] + return self._sc.parallelize(data), schema + @since(1.3) @ignore_unicode_prefix def createDataFrame(self, data, schema=None, samplingRatio=None): @@ -340,49 +400,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") - if has_pandas and isinstance(data, pandas.DataFrame): - if schema is None: - schema = [str(x) for x in data.columns] - data = [r.tolist() for r in data.to_records(index=False)] - - if not isinstance(data, RDD): - if not isinstance(data, list): - data = list(data) - try: - # data could be list, tuple, generator ... 
- rdd = self._sc.parallelize(data) - except Exception: - raise TypeError("cannot create an RDD from type: %s" % type(data)) + if isinstance(data, RDD): + rdd, schema = self._createFromRDD(data, schema, samplingRatio) else: - rdd = data - - if schema is None or isinstance(schema, (list, tuple)): - if isinstance(data, RDD): - struct = self._inferSchema(rdd, samplingRatio) - else: - struct = self._inferSchemaFromList(data) - if isinstance(schema, (list, tuple)): - for i, name in enumerate(schema): - struct.fields[i].name = name - schema = struct - converter = _create_converter(schema) - rdd = rdd.map(converter) - - elif isinstance(schema, StructType): - # take the first few rows to verify schema - rows = rdd.take(10) - for row in rows: - _verify_type(row, schema) - - else: - raise TypeError("schema should be StructType or list or None") - - # convert python objects to sql data - rdd = rdd.map(schema.toInternal) - + rdd, schema = self._createFromLocal(data, schema) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return DataFrame(df, self) + jdf = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + df = DataFrame(jdf, self) + df._schema = schema + return df @since(1.3) def registerDataFrameAsTable(self, df, tableName): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5aa6135dc1ee7..ebd3ea8db6a43 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -75,7 +75,7 @@ def sqlType(self): @classmethod def module(cls): - return 'pyspark.tests' + return 'pyspark.sql.tests' @classmethod def scalaUDT(cls): @@ -106,10 +106,45 @@ def __str__(self): return "(%s,%s)" % (self.x, self.y) def __eq__(self, other): - return isinstance(other, ExamplePoint) and \ + return isinstance(other, self.__class__) and \ other.x == self.x and other.y == self.y +class PythonOnlyUDT(UserDefinedType): + """ + User-defined type (UDT) for PythonOnlyPoint.
+ """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return '__main__' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return PythonOnlyPoint(datum[0], datum[1]) + + @staticmethod + def foo(): + pass + + @property + def props(self): + return {} + + +class PythonOnlyPoint(ExamplePoint): + """ + An example class to demonstrate UDT in only Python + """ + __UDT__ = PythonOnlyUDT() + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -395,10 +430,39 @@ def test_convert_row_to_dict(self): self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) + def test_udt(self): + from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _verify_type + from pyspark.sql.tests import ExamplePointUDT, ExamplePoint + + def check_datatype(datatype): + pickled = pickle.loads(pickle.dumps(datatype)) + assert datatype == pickled + scala_datatype = self.sqlCtx._ssql_ctx.parseDataType(datatype.json()) + python_datatype = _parse_datatype_json_string(scala_datatype.json()) + assert datatype == python_datatype + + check_datatype(ExamplePointUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + check_datatype(structtype_with_udt) + p = ExamplePoint(1.0, 2.0) + self.assertEqual(_infer_type(p), ExamplePointUDT()) + _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], ExamplePointUDT())) + + check_datatype(PythonOnlyUDT()) + structtype_with_udt = StructType([StructField("label", DoubleType(), False), + StructField("point", PythonOnlyUDT(), False)]) + check_datatype(structtype_with_udt) + p = PythonOnlyPoint(1.0, 2.0) + self.assertEqual(_infer_type(p), PythonOnlyUDT()) + _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) + self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) schema = df.schema field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) @@ -406,36 +470,66 @@ def test_infer_schema_with_udt(self): point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + schema = df.schema + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), PythonOnlyUDT) + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point + self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) + def test_apply_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = (1.0, ExamplePoint(1.0, 2.0)) - rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - df = rdd.toDF(schema) + df = self.sqlCtx.createDataFrame([row], schema) point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = (1.0, PythonOnlyPoint(1.0, 2.0)) + schema = StructType([StructField("label", DoubleType(), 
False), + StructField("point", PythonOnlyUDT(), False)]) + df = self.sqlCtx.createDataFrame([row], schema) + point = df.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_udf_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df = self.sc.parallelize([row]).toDF() + df = self.sqlCtx.createDataFrame([row]) self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df = self.sqlCtx.createDataFrame([row]) + self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) + udf = UserDefinedFunction(lambda p: p.y, DoubleType()) + self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT()) + self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) + def test_parquet_with_udt(self): - from pyspark.sql.tests import ExamplePoint + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) - df0 = self.sc.parallelize([row]).toDF() + df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") - df0.saveAsParquetFile(output_dir) + df0.write.parquet(output_dir) df1 = self.sqlCtx.parquetFile(output_dir) point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) + df0 = self.sqlCtx.createDataFrame([row]) + df0.write.parquet(output_dir, mode='overwrite') + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point + self.assertEquals(point, PythonOnlyPoint(1.0, 2.0)) + def test_column_operators(self): ci = self.df.key cs = self.df.value diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8859308d66027..0976aea72c034 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -22,6 +22,7 @@ import calendar import json import re +import base64 from array import array if sys.version >= "3": @@ -31,6 +32,8 @@ from py4j.protocol import register_input_converter from py4j.java_gateway import JavaClass +from pyspark.serializers import CloudPickleSerializer + __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", @@ -458,7 +461,7 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeFields = None + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -501,6 +504,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self.fields) return self def simpleString(self): @@ -526,10 +530,7 @@ def toInternal(self, obj): if obj is None: return - if self._needSerializeFields is None: - self._needSerializeFields = 
any(f.needConversion() for f in self.fields) - - if self._needSerializeFields: + if self._needSerializeAnyField: if isinstance(obj, dict): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): @@ -550,7 +551,10 @@ def fromInternal(self, obj): if isinstance(obj, Row): # it's already converted by pickler return obj - values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + if self._needSerializeAnyField: + values = [f.fromInternal(v) for f, v in zip(self.fields, obj)] + else: + values = obj return _create_row(self.names, values) @@ -581,9 +585,10 @@ def module(cls): @classmethod def scalaUDT(cls): """ - The class name of the paired Scala UDT. + The class name of the paired Scala UDT (could be '', if there + is no corresponding one). """ - raise NotImplementedError("UDT must have a paired Scala UDT.") + return '' def needConversion(self): return True @@ -622,22 +627,37 @@ def json(self): return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) def jsonValue(self): - schema = { - "type": "udt", - "class": self.scalaUDT(), - "pyClass": "%s.%s" % (self.module(), type(self).__name__), - "sqlType": self.sqlType().jsonValue() - } + if self.scalaUDT(): + assert self.module() != '__main__', 'UDT in __main__ cannot work with ScalaUDT' + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + else: + ser = CloudPickleSerializer() + b = ser.dumps(type(self)) + schema = { + "type": "udt", + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "serializedClass": base64.b64encode(b).decode('utf8'), + "sqlType": self.sqlType().jsonValue() + } return schema @classmethod def fromJson(cls, json): - pyUDT = json["pyClass"] + pyUDT = str(json["pyClass"]) split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] m = __import__(pyModule, globals(), locals(), [pyClass]) - UDT = getattr(m, pyClass) + if not hasattr(m, pyClass): + s = base64.b64decode(json['serializedClass'].encode('utf-8')) + UDT = CloudPickleSerializer().loads(s) + else: + UDT = getattr(m, pyClass) return UDT() def __eq__(self, other): @@ -696,11 +716,6 @@ def _parse_datatype_json_string(json_string): >>> complex_maptype = MapType(complex_structtype, ... complex_arraytype, False) >>> check_datatype(complex_maptype) - - >>> check_datatype(ExamplePointUDT()) - >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> check_datatype(structtype_with_udt) """ return _parse_datatype_json_value(json.loads(json_string)) @@ -752,10 +767,6 @@ def _parse_datatype_json_value(json_value): def _infer_type(obj): """Infer the DataType from obj - - >>> p = ExamplePoint(1.0, 2.0) - >>> _infer_type(p) - ExamplePointUDT """ if obj is None: return NullType() @@ -1090,11 +1101,6 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... - >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) - >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError:... 
""" # all objects are nullable if obj is None: @@ -1259,18 +1265,12 @@ def convert(self, obj, gateway_client): def _test(): import doctest from pyspark.context import SparkContext - # let doctest run in pyspark.sql.types, so DataTypes can be picklable - import pyspark.sql.types - from pyspark.sql import Row, SQLContext - from pyspark.sql.tests import ExamplePoint, ExamplePointUDT - globs = pyspark.sql.types.__dict__.copy() + from pyspark.sql import SQLContext + globs = globals() sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) - globs['ExamplePoint'] = ExamplePoint - globs['ExamplePointUDT'] = ExamplePointUDT - (failure_count, test_count) = doctest.testmod( - pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 591fb26e67c4a..f4428c2e8b202 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -142,12 +142,21 @@ object DataType { ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + // Scala/Java UDT case JSortedObject( ("class", JString(udtClass)), ("pyClass", _), ("sqlType", _), ("type", JString("udt"))) => Utils.classForName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] + + // Python UDT + case JSortedObject( + ("pyClass", JString(pyClass)), + ("serializedClass", JString(serialized)), + ("sqlType", v: JValue), + ("type", JString("udt"))) => + new PythonUserDefinedType(parseDataType(v), pyClass, serialized) } private def parseStructField(json: JValue): StructField = json match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index e47cfb4833bd8..4305903616bd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -45,6 +45,9 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { /** Paired Python UDT class, if exists. */ def pyUDT: String = null + /** Serialized Python UDT class, if exists. */ + def serializedPyClass: String = null + /** * Convert the user type to a SQL datum * @@ -82,3 +85,29 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass } + +/** + * ::DeveloperApi:: + * The user defined type in Python. + * + * Note: This can only be accessed via Python UDF, or accessed as serialized object. 
+ */ +private[sql] class PythonUserDefinedType( + val sqlType: DataType, + override val pyUDT: String, + override val serializedPyClass: String) extends UserDefinedType[Any] { + + /* The serialization is handled by UDT class in Python */ + override def serialize(obj: Any): Any = obj + override def deserialize(datam: Any): Any = datam + + /* There is no Java class for Python UDT */ + override def userClass: java.lang.Class[Any] = null + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("pyClass" -> pyUDT) ~ + ("serializedClass" -> serializedPyClass) ~ + ("sqlType" -> sqlType.jsonValue) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index ec084a299649e..3c38916fd7504 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -267,7 +267,6 @@ object EvaluatePython { pickler.save(row.values(i)) i += 1 } - row.values.foreach(pickler.save) out.write(Opcodes.TUPLE) out.write(Opcodes.REDUCE) } From 712465b68e50df7a2050b27528acda9f0d95ba1f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 29 Jul 2015 22:51:06 -0700 Subject: [PATCH 150/219] HOTFIX: disable HashedRelationSuite. --- .../spark/sql/execution/joins/HashedRelationSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 8b1a9b21a96b9..941f6d4f6a450 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -33,7 +33,7 @@ class HashedRelationSuite extends SparkFunSuite { override def apply(row: InternalRow): InternalRow = row } - test("GeneralHashedRelation") { + ignore("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -47,7 +47,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed.get(data(2)) === data2) } - test("UniqueKeyHashedRelation") { + ignore("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -64,7 +64,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(InternalRow(10)) === null) } - test("UnsafeHashedRelation") { + ignore("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) From e127ec34d58ceb0a9d45748c2f2918786ba0a83d Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Wed, 29 Jul 2015 23:24:20 -0700 Subject: [PATCH 151/219] [SPARK-9428] [SQL] Add test cases for null inputs for expression unit tests JIRA: https://issues.apache.org/jira/browse/SPARK-9428 Author: Yijie Shen Closes #7748 from yjshen/string_cleanup and squashes the following commits: e0c2b3d [Yijie Shen] update codegen in RegExpExtract and RegExpReplace 26614d2 [Yijie Shen] MathFunctionSuite a402859 [Yijie Shen] complex_create, conditional and cast 
6e4e608 [Yijie Shen] arithmetic and cast 52593c1 [Yijie Shen] null input test cases for StringExpressionSuite --- .../spark/sql/catalyst/expressions/Cast.scala | 12 ++-- .../expressions/complexTypeCreator.scala | 16 +++-- .../catalyst/expressions/conditionals.scala | 10 +-- .../spark/sql/catalyst/expressions/math.scala | 14 ++--- .../expressions/stringOperations.scala | 11 ++-- .../ExpressionTypeCheckingSuite.scala | 7 ++- .../ArithmeticExpressionSuite.scala | 3 + .../sql/catalyst/expressions/CastSuite.scala | 52 ++++++++++++++- .../expressions/ComplexTypeSuite.scala | 23 +++---- .../ConditionalExpressionSuite.scala | 4 ++ .../expressions/MathFunctionsSuite.scala | 63 ++++++++++--------- .../catalyst/expressions/RandomSuite.scala | 1 - .../expressions/StringExpressionsSuite.scala | 26 ++++++++ .../org/apache/spark/sql/functions.scala | 6 +- 14 files changed, 167 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c6e8af27667ee..8c01c13c9ccd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -599,7 +599,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" case _: IntegralType => (c, evPrim, evNull) => s"$evPrim = ${longToTimeStampCode(c)};" case DateType => @@ -665,7 +665,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -687,7 +687,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -731,7 +731,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -753,7 +753,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 1.0f : 0.0f;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => @@ -775,7 +775,7 @@ case class Cast(child: Expression, dataType: DataType) } """ case BooleanType => - (c, evPrim, evNull) => s"$evPrim = $c ? 1 : 0;" + (c, evPrim, evNull) => s"$evPrim = $c ? 
1.0d : 0.0d;" case DateType => (c, evPrim, evNull) => s"$evNull = true;" case TimestampType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d8c9087ff5380..0517050a45109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.unsafe.types.UTF8String + import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow @@ -127,11 +129,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip - private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) + private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -144,14 +147,15 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { if (children.size % 2 != 0) { TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.") } else { - val invalidNames = - nameExprs.filterNot(e => e.foldable && e.dataType == StringType && !nullable) + val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType) if (invalidNames.nonEmpty) { TypeCheckResult.TypeCheckFailure( - s"Odd position only allow foldable and not-null StringType expressions, got :" + + s"Only foldable StringType expressions are allowed to appear at odd position , got :" + s" ${invalidNames.mkString(",")}") - } else { + } else if (names.forall(_ != null)){ TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure("Field name should not be null") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index 15b33da884dcb..961b1d8616801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -315,7 +315,6 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW * It takes at least 2 parameters, and returns null iff all parameters are null. 
*/ case class Least(children: Seq[Expression]) extends Expression { - require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -323,7 +322,9 @@ case class Least(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got LEAST (${children.map(_.dataType)}).") @@ -369,7 +370,6 @@ case class Least(children: Seq[Expression]) extends Expression { * It takes at least 2 parameters, and returns null iff all parameters are null. */ case class Greatest(children: Seq[Expression]) extends Expression { - require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length) override def nullable: Boolean = children.forall(_.nullable) override def foldable: Boolean = children.forall(_.foldable) @@ -377,7 +377,9 @@ case class Greatest(children: Seq[Expression]) extends Expression { private lazy val ordering = TypeUtils.getOrdering(dataType) override def checkInputDataTypes(): TypeCheckResult = { - if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { + if (children.length <= 1) { + TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments") + } else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) { TypeCheckResult.TypeCheckFailure( s"The expressions should all have the same type," + s" got GREATEST (${children.map(_.dataType)}).") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 68cca0ad3d067..e6d807f6d897b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -646,19 +646,19 @@ case class Logarithm(left: Expression, right: Expression) /** * Round the `child`'s result to `scale` decimal place when `scale` >= 0 * or round at integral part when `scale` < 0. - * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30. + * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30. + * - * Child of IntegralType would eval to itself when `scale` >= 0. - * Child of FractionalType whose value is NaN or Infinite would always eval to itself. + * Child of IntegralType would round to itself when `scale` >= 0. + * Child of FractionalType whose value is NaN or Infinite would always round to itself. - * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed]], - * which leads to scale update in DecimalType's [[PrecisionInfo]] + * Round's dataType would always equal to `child`'s dataType except for DecimalType, + * which would lead to a scale decrease from the original DecimalType.
* * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ case class Round(child: Expression, scale: Expression) - extends BinaryExpression with ExpectsInputTypes { + extends BinaryExpression with ImplicitCastInputTypes { import BigDecimal.RoundingMode.HALF_UP @@ -838,6 +838,4 @@ case class Round(child: Expression, scale: Expression) """ } } - - override def prettyName: String = "round" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 6db4e19c24ed5..5b3a64a09679c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -22,7 +22,6 @@ import java.util.Locale import java.util.regex.{MatchResult, Pattern} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -52,7 +51,7 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; @@ -1008,7 +1007,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio s""" ${evalSubject.code} - boolean ${ev.isNull} = ${evalSubject.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${evalSubject.isNull}) { ${evalRegexp.code} @@ -1103,9 +1102,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val evalIdx = idx.gen(ctx) s""" - ${ctx.javaType(dataType)} ${ev.primitive} = null; - boolean ${ev.isNull} = true; ${evalSubject.code} + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + boolean ${ev.isNull} = true; if (!${evalSubject.isNull}) { ${evalRegexp.code} if (!${evalRegexp.isNull}) { @@ -1117,7 +1116,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + ${termPattern}.matcher(${evalSubject.primitive}.toString()); if (m.find()) { ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8acd4c685e2bc..a52e4cb4dfd9f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -167,10 +167,13 @@ class 
ExpressionTypeCheckingSuite extends SparkFunSuite { CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments") assertError( CreateNamedStruct(Seq(1, "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") assertError( CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)), - "Odd position only allow foldable and not-null StringType expressions") + "Only foldable StringType expressions are allowed to appear at odd position") + assertError( + CreateNamedStruct(Seq(Literal.create(null, StringType), "a")), + "Field name should not be null") } test("check types for ROUND") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 7773e098e0caa..d03b0fbbfb2b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -116,9 +116,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("Abs") { testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType checkEvaluation(Abs(Literal(convert(0))), convert(0)) checkEvaluation(Abs(Literal(convert(1))), convert(1)) checkEvaluation(Abs(Literal(convert(-1))), convert(1)) + checkEvaluation(Abs(Literal.create(null, dataType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 0e0213be0f57b..a517da9872852 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -43,6 +43,42 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(v, Literal(expected).dataType), expected) } + private def checkNullCast(from: DataType, to: DataType): Unit = { + checkEvaluation(Cast(Literal.create(null, from), to), null) + } + + test("null cast") { + import DataTypeTestUtils._ + + // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic + // to ensure we test every possible cast situation here + atomicTypes.zip(atomicTypes).foreach { case (from, to) => + checkNullCast(from, to) + } + + atomicTypes.foreach(dt => checkNullCast(NullType, dt)) + atomicTypes.foreach(dt => checkNullCast(dt, StringType)) + checkNullCast(StringType, BinaryType) + checkNullCast(StringType, BooleanType) + checkNullCast(DateType, BooleanType) + checkNullCast(TimestampType, BooleanType) + numericTypes.foreach(dt => checkNullCast(dt, BooleanType)) + + checkNullCast(StringType, TimestampType) + checkNullCast(BooleanType, TimestampType) + checkNullCast(DateType, TimestampType) + numericTypes.foreach(dt => checkNullCast(dt, TimestampType)) + + atomicTypes.foreach(dt => checkNullCast(dt, DateType)) + + checkNullCast(StringType, CalendarIntervalType) + numericTypes.foreach(dt => checkNullCast(StringType, dt)) + numericTypes.foreach(dt => checkNullCast(BooleanType, dt)) + numericTypes.foreach(dt => checkNullCast(DateType, dt)) + numericTypes.foreach(dt => checkNullCast(TimestampType, dt)) + for (from <- numericTypes; to <- numericTypes) checkNullCast(from, 
to) + } + test("cast string to date") { var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -69,8 +105,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast string to timestamp") { - checkEvaluation(Cast(Literal("123"), TimestampType), - null) + checkEvaluation(Cast(Literal("123"), TimestampType), null) var c = Calendar.getInstance() c.set(2015, 0, 1, 0, 0, 0) @@ -473,6 +508,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val array_notNull = Literal.create(Seq("123", "abc", ""), ArrayType(StringType, containsNull = false)) + checkNullCast(ArrayType(StringType), ArrayType(IntegerType)) + { val ret = cast(array, ArrayType(IntegerType, containsNull = true)) assert(ret.resolved === true) @@ -526,6 +563,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { Map("a" -> "123", "b" -> "abc", "c" -> ""), MapType(StringType, StringType, valueContainsNull = false)) + checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType)) + { val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true)) assert(ret.resolved === true) @@ -580,6 +619,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("cast from struct") { + checkNullCast( + StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType))), + StructType(Seq( + StructField("a", StringType), + StructField("b", StringType)))) + val struct = Literal.create( InternalRow( UTF8String.fromString("123"), @@ -728,5 +775,4 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType), "interval 1 years 3 months -3 days") } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index fc842772f3480..5de5ddce975d8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -132,6 +132,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(CreateArray(intWithNull), intSeq :+ null, EmptyRow) checkEvaluation(CreateArray(longWithNull), longSeq :+ null, EmptyRow) checkEvaluation(CreateArray(strWithNull), strSeq :+ null, EmptyRow) + checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil) } test("CreateStruct") { @@ -139,26 +140,20 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null)) } test("CreateNamedStruct") { - val row = InternalRow(1, 2, 3) + val row = create_row(1, 2, 3) val c1 = 'a.int.at(0) val c3 = 'c.int.at(2) - checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), InternalRow(1, 3), row) - } - - test("CreateNamedStruct with literal field") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0) + checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row) checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), - InternalRow(1, UTF8String.fromString("y")), row) - } - - test("CreateNamedStruct from all literal fields") { - checkEvaluation( - CreateNamedStruct(Seq("a", "x", "b", 2.0)), - InternalRow(UTF8String.fromString("x"), 2.0), InternalRow.empty) + create_row(1, 
UTF8String.fromString("y")), row) + checkEvaluation(CreateNamedStruct(Seq("a", "x", "b", 2.0)), + create_row(UTF8String.fromString("x"), 2.0)) + checkEvaluation(CreateNamedStruct(Seq("a", Literal.create(null, IntegerType))), + create_row(null)) } test("test dsl for complex type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index b31d6661c8c1c..d26bcdb2902ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -149,6 +149,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Least(Seq(c1, c2, Literal(-1))), -1, row) checkEvaluation(Least(Seq(c4, c5, c3, c3, Literal("a"))), "a", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Least(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Least(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1.0), Literal(2.5))), -1.0, InternalRow.empty) checkEvaluation(Least(Seq(Literal(-1), Literal(2))), -1, InternalRow.empty) @@ -188,6 +190,8 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row) checkEvaluation(Greatest(Seq(c4, c5, c3, Literal("ccc"))), "ccc", row) + val nullLiteral = Literal.create(null, IntegerType) + checkEvaluation(Greatest(Seq(nullLiteral, nullLiteral)), null) checkEvaluation(Greatest(Seq(Literal(null), Literal(null))), null, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1.0), Literal(2.5))), 2.5, InternalRow.empty) checkEvaluation(Greatest(Seq(Literal(-1), Literal(2))), 2, InternalRow.empty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 21459a7c69838..9fcb548af6bbb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -110,35 +110,17 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(c(Literal(1.0), Literal.create(null, DoubleType)), null, create_row(null)) } - test("conv") { - checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") - checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") - checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) - checkEvaluation( - Conv(Literal("1234"), Literal(10), Literal(37)), null) - checkEvaluation( - Conv(Literal(""), Literal(10), Literal(16)), null) - checkEvaluation( - Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") - // If there is an invalid digit in the number, the longest valid prefix should be converted. 
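One practical consequence of the Least/Greatest change in this patch (the require moved into checkInputDataTypes), which the ConditionalExpressionSuite rows above exercise alongside the new null-literal cases: an arity mistake now surfaces through the normal analysis channel instead of an IllegalArgumentException thrown while the expression tree is still being built. A hedged, test-style sketch, assuming the Catalyst test classpath:

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Least, Literal}

// Construction no longer throws; the arity problem is reported as a
// regular type-check failure that the analyzer can turn into a user error.
val tooFewArgs = Least(Seq(Literal(1)))
tooFewArgs.checkInputDataTypes() match {
  case TypeCheckResult.TypeCheckFailure(msg) =>
    assert(msg.contains("requires at least 2 arguments"))
  case other =>
    sys.error(s"expected a type-check failure, got $other")
}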
- checkEvaluation( - Conv(Literal("11abc"), Literal(10), Literal(16)), "B") - } - private def checkNaN( - expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { checkNaNWithoutCodegen(expression, inputRow) checkNaNWithGeneratedProjection(expression, inputRow) checkNaNWithOptimization(expression, inputRow) } private def checkNaNWithoutCodegen( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } @@ -149,7 +131,6 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } - private def checkNaNWithGeneratedProjection( expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { @@ -172,6 +153,25 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } + test("conv") { + checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") + checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") + checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null) + checkEvaluation( + Conv(Literal("1234"), Literal(10), Literal(37)), null) + checkEvaluation( + Conv(Literal(""), Literal(10), Literal(16)), null) + checkEvaluation( + Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF") + // If there is an invalid digit in the number, the longest valid prefix should be converted. 
+ checkEvaluation( + Conv(Literal("11abc"), Literal(10), Literal(16)), "B") + } + test("e") { testLeaf(EulerNumber, math.E) } @@ -417,7 +417,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("round") { - val domain = -6 to 6 + val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 val intPi: Int = 314159265 @@ -437,17 +437,16 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L) - val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), - BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), - BigDecimal(3.141593), BigDecimal(3.1415927)) - - domain.zipWithIndex.foreach { case (scale, i) => + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) } + val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), + BigDecimal(3.142), BigDecimal(3.1416), BigDecimal(3.14159), + BigDecimal(3.141593), BigDecimal(3.1415927)) // round_scale > current_scale would result in a precision increase, which is // not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null is returned (0 to 7).foreach { i => @@ -456,5 +455,11 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) } + + DataTypeTestUtils.numericTypes.foreach { dataType => + checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(Round(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 5db992654811a..4a644d136f09c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -21,7 +21,6 @@ import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite - class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3d294fda5d103..07b952531ec2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -348,6 +348,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringTrimLeft(s), "花花世界 ", create_row(" 花花世界 ")) checkEvaluation(StringTrim(s), "花花世界", create_row(" 花花世界 ")) // scalastyle:on + checkEvaluation(StringTrim(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimLeft(Literal.create(null, StringType)), null) + checkEvaluation(StringTrimRight(Literal.create(null, StringType)), null) } test("FORMAT") { @@ -391,6 +394,9 @@ class StringExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper { val s3 = 'c.string.at(2) val s4 = 'd.int.at(3) val row1 = create_row("aaads", "aa", "zz", 1) + val row2 = create_row(null, "aa", "zz", 0) + val row3 = create_row("aaads", null, "zz", 0) + val row4 = create_row(null, null, null, 0) checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1) checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 2, row1) @@ -402,6 +408,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringLocate(s2, s1, s4), 2, row1) checkEvaluation(new StringLocate(s3, s1), 0, row1) checkEvaluation(StringLocate(s3, s1, Literal.create(null, IntegerType)), 0, row1) + checkEvaluation(new StringLocate(s2, s1), null, row2) + checkEvaluation(new StringLocate(s2, s1), null, row3) + checkEvaluation(new StringLocate(s2, s1, Literal.create(null, IntegerType)), 0, row4) } test("LPAD/RPAD") { @@ -448,6 +457,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("abccc") checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1) checkEvaluation(StringReverse(s), "cccba", row1) + checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1) } test("SPACE") { @@ -466,6 +476,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("100-200", "(\\d+)", "num") val row2 = create_row("100-200", "(\\d+)", "###") val row3 = create_row("100-200", "(-)", "###") + val row4 = create_row(null, "(\\d+)", "###") + val row5 = create_row("100-200", null, "###") + val row6 = create_row("100-200", "(-)", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -475,6 +488,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "num-num", row1) checkEvaluation(expr, "###-###", row2) checkEvaluation(expr, "100###200", row3) + checkEvaluation(expr, null, row4) + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) } test("RegexExtract") { @@ -482,6 +498,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) val row3 = create_row("100-200", "(\\d+).*", 1) val row4 = create_row("100-200", "([a-z])", 1) + val row5 = create_row(null, "([a-z])", 1) + val row6 = create_row("100-200", null, 1) + val row7 = create_row("100-200", "([a-z])", null) val s = 's.string.at(0) val p = 'p.string.at(1) @@ -492,6 +511,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(expr, "200", row2) checkEvaluation(expr, "100", row3) checkEvaluation(expr, "", row4) // will not match anything, so an empty string is returned + checkEvaluation(expr, null, row5) + checkEvaluation(expr, null, row6) + checkEvaluation(expr, null, row7) val expr1 = new RegExpExtract(s, p) checkEvaluation(expr1, "100", row1) @@ -501,11 +523,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) val row1 = create_row("aa2bb3cc", "[1-9]+") + val row2 = create_row(null, "[1-9]+") + val row3 = create_row("aa2bb3cc", null) checkEvaluation( StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) checkEvaluation( StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2), null, row2) + checkEvaluation(StringSplit(s1, s2), null, row3) } test("length for string / binary") { diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4261a5e7cbeb5..4e68a88e7cda6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1423,7 +1423,8 @@ object functions { def round(columnName: String): Column = round(Column(columnName), 0) /** - * Returns the value of `e` rounded to `scale` decimal places. + * Returns the value of `e` rounded to `scale` decimal places if `scale` >= 0, + * or rounded at the integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 @@ -1431,7 +1432,8 @@ object functions { def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) /** - * Returns the value of the given column rounded to `scale` decimal places. + * Returns the value of the given column rounded to `scale` decimal places if `scale` >= 0, + * or rounded at the integral part when `scale` < 0. * * @group math_funcs * @since 1.5.0 From 1221849f91739454b8e495889cba7498ba8beea7 Mon Sep 17 00:00:00 2001 From: Joseph Batchik Date: Wed, 29 Jul 2015 23:35:55 -0700 Subject: [PATCH 152/219] [SPARK-8005][SQL] Input file name Users can now get the file name of the partition being read in. A thread local variable is in `SQLNewHadoopRDD` and is set when the partition is computed. `SQLNewHadoopRDD` is moved to core so that the catalyst package can reach it. This supports: `df.select(inputFileName())` and `sqlContext.sql("select input_file_name() from table")` Author: Joseph Batchik Closes #7743 from JDrit/input_file_name and squashes the following commits: abb8609 [Joseph Batchik] fixed failing test and changed the default value to be an empty string d2f323d [Joseph Batchik] updates per review 102061f [Joseph Batchik] updates per review 75313f5 [Joseph Batchik] small fixes c7f7b5a [Joseph Batchik] addeding input file name to Spark SQL --- .../apache/spark/rdd}/SqlNewHadoopRDD.scala | 34 +++++++++++-- .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../catalyst/expressions/InputFileName.scala | 49 +++++++++++++++++++ .../expressions/SparkPartitionID.scala | 2 + .../expressions/NondeterministicSuite.scala | 4 ++ .../org/apache/spark/sql/functions.scala | 9 ++++ .../spark/sql/parquet/ParquetRelation.scala | 3 +- .../spark/sql/ColumnExpressionSuite.scala | 17 ++++++- .../scala/org/apache/spark/sql/UDFSuite.scala | 17 ++++++- .../org/apache/spark/sql/hive/UDFSuite.scala | 6 --- 10 files changed, 128 insertions(+), 16 deletions(-) rename {sql/core/src/main/scala/org/apache/spark/sql/execution => core/src/main/scala/org/apache/spark/rdd}/SqlNewHadoopRDD.scala (91%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala similarity index 91% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala rename to core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala index 3d75b6a91def6..35e44cb59c1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -15,12 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution +package org.apache.spark.rdd import java.text.SimpleDateFormat import java.util.Date -import org.apache.spark.{Partition => SparkPartition, _} +import scala.reflect.ClassTag + import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ @@ -30,12 +31,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, Utils} -import scala.reflect.ClassTag private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -62,7 +63,7 @@ private[spark] class SqlNewHadoopPartition( * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be * folded into core. */ -private[sql] class SqlNewHadoopRDD[K, V]( +private[spark] class SqlNewHadoopRDD[K, V]( @transient sc : SparkContext, broadcastedConf: Broadcast[SerializableConfiguration], @transient initDriverSideJobFuncOpt: Option[Job => Unit], @@ -128,6 +129,12 @@ private[sql] class SqlNewHadoopRDD[K, V]( val inputMetrics = context.taskMetrics .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { @@ -188,6 +195,8 @@ private[sql] class SqlNewHadoopRDD[K, V]( reader.close() reader = null + SqlNewHadoopRDD.unsetInputFileName() + if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || @@ -250,6 +259,21 @@ private[sql] class SqlNewHadoopRDD[K, V]( } private[spark] object SqlNewHadoopRDD { + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + /** * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to * the given function rather than the index of the partition. 
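For readers, the thread-local plumbing above is what the new `inputFileName` function surfaces to queries; a minimal usage sketch (an illustration only, not part of the patch — the SQLContext name and the Parquet path are invented for the example):

    import org.apache.spark.sql.functions.inputFileName

    // Each row is tagged with the path of the part-file it was read from.
    // Rows not produced through SqlNewHadoopRDD see the thread-local default,
    // which is the empty string.
    val df = sqlContext.read.parquet("/tmp/events")
    df.select(inputFileName()).show()
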
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 372f80d4a8b16..378df4f57d9e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -230,7 +230,8 @@ object FunctionRegistry { expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), - expression[SparkPartitionID]("spark_partition_id") + expression[SparkPartitionID]("spark_partition_id"), + expression[InputFileName]("input_file_name") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala new file mode 100644 index 0000000000000..1e74f716955e3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Expression that returns the name of the current file being read, as set by [[SqlNewHadoopRDD]]. + */ +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override val prettyName = "INPUT_FILE_NAME" + + override protected def initInternal(): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + SqlNewHadoopRDD.getInputFileName() + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = " + + "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 3f6480bbf0114..4b1772a2deed5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -34,6 +34,8 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm @transient private[this] var partitionId: Int = _ + override val prettyName = "SPARK_PARTITION_ID" + override protected def initInternal(): Unit = { partitionId = TaskContext.getPartitionId() } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala index 82894822ab0f4..bf1c930c0bd0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala @@ -27,4 +27,8 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("SparkPartitionID") { checkEvaluation(SparkPartitionID(), 0) } + + test("InputFileName") { + checkEvaluation(InputFileName(), "") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4e68a88e7cda6..a2fece62f61f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -743,6 +743,15 @@ object functions { */ def sparkPartitionId(): Column = SparkPartitionID() + /** + * The name of the file being read by the current Spark task. + * + * Note that this is non-deterministic because it depends on what is currently being read in. + * + * @group normal_funcs + */ + def inputFileName(): Column = InputFileName() + /** * Computes the square root of the specified float value. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cc6fa2b88663f..1a8176d8a80ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} import org.apache.spark.rdd.RDD._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1f9f7118c3f04..5c1102410879a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,13 +22,16 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest { +class ColumnExpressionSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("InputFileName") { + withTempPath { dir => + val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d9c8b380ef146..183dc3407b3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { +class UDFSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") @@ -58,6 +61,18 @@ class UDFSuite extends QueryTest { ctx.dropTempTable("tmp_table") } + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = 
ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + ctx.dropTempTable("test_table") + } + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 37afc2142abf7..9b3ede43ee2d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -34,10 +34,4 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } - - test("SPARK-8003 spark_partition_id") { - val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying") - ctx.registerDataFrameAsTable(df, "test_table") - checkAnswer(ctx.sql("select spark_partition_id() from test_table LIMIT 1").toDF(), Row(0)) - } } From 76f2e393a5fad0db8b56c4b8dad5ef686bf140a4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 30 Jul 2015 00:46:36 -0700 Subject: [PATCH 153/219] [SPARK-9335] [TESTS] Enable Kinesis tests only when files in extras/kinesis-asl are changed Author: zsxwing Closes #7711 from zsxwing/SPARK-9335-test and squashes the following commits: c13ec2f [zsxwing] environs -> environ 69c2865 [zsxwing] Merge remote-tracking branch 'origin/master' into SPARK-9335-test ef84a08 [zsxwing] Revert "Modify the Kinesis project to trigger ENABLE_KINESIS_TESTS" f691028 [zsxwing] Modify the Kinesis project to trigger ENABLE_KINESIS_TESTS 7618205 [zsxwing] Enable Kinesis tests only when files in extras/kinesis-asl are changed --- dev/run-tests.py | 16 ++++++++++++++++ dev/sparktestsupport/modules.py | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 1f0d218514f92..29420da9aa956 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -85,6 +85,13 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe return [f for f in raw_output.split('\n') if f] +def setup_test_environ(environ): + print("[info] Setup the following environment variables for tests: ") + for (k, v) in environ.items(): + print("%s=%s" % (k, v)) + os.environ[k] = v + + def determine_modules_to_test(changed_modules): """ Given a set of modules that have changed, compute the transitive closure of those modules' @@ -455,6 +462,15 @@ def main(): print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules)) + # setup environment variables + # note - the 'root' module doesn't collect environment variables for all modules. Because the + # environment variables should not be set if a module is not changed, even if running the 'root' + # module. So here we should use changed_modules rather than test_modules. 
+ test_environ = {} + for m in changed_modules: + test_environ.update(m.environ) + setup_test_environ(test_environ) + test_modules = determine_modules_to_test(changed_modules) # license checks diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 3073d489bad4a..030d982e99106 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -29,7 +29,7 @@ class Module(object): changed. """ - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), should_run_r_tests=False): """ @@ -43,6 +43,8 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= filename strings. :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in order to build and test this module (e.g. '-PprofileName'). + :param environ: A dict of environment variables that should be set when files in this + module are changed. :param sbt_test_goals: A set of SBT test goals for testing this module. :param python_test_goals: A set of Python test goals for testing this module. :param blacklisted_python_implementations: A set of Python implementations that are not @@ -55,6 +57,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.source_file_prefixes = source_file_regexes self.sbt_test_goals = sbt_test_goals self.build_profile_flags = build_profile_flags + self.environ = environ self.python_test_goals = python_test_goals self.blacklisted_python_implementations = blacklisted_python_implementations self.should_run_r_tests = should_run_r_tests @@ -126,15 +129,22 @@ def contains_file(self, filename): ) +# Don't set the dependencies because changes in other modules should not trigger Kinesis tests. +# Kinesis tests depend on the external Amazon Kinesis service. We should run these tests only when +# files in streaming_kinesis_asl are changed, so that if Kinesis experiences an outage, we don't +# fail other PRs. streaming_kinesis_asl = Module( name="kinesis-asl", - dependencies=[streaming], + dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", ], build_profile_flags=[ "-Pkinesis-asl", ], + environ={ + "ENABLE_KINESIS_TESTS": "1" + }, sbt_test_goals=[ "kinesis-asl/test", ] From 4a8bb9d00d8181aff5f5183194d9aa2a65deacdf Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 01:04:24 -0700 Subject: [PATCH 154/219] Revert "[SPARK-9458] Avoid object allocation in prefix generation." This reverts commit 9514d874f0cf61f1eb4ec4f5f66e053119f769c9. 
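For context on the diff that follows: a sort prefix packs an order-preserving summary of a value into a single long, so most row comparisons can be decided without touching the full records. A conceptual sketch (an illustration only, not part of the patch, using the existing PrefixComparators API this revert touches):

    import org.apache.spark.util.collection.unsafe.sort.PrefixComparators

    // Prefixes of doubles compare the same way the doubles themselves do.
    val p1 = PrefixComparators.DOUBLE.computePrefix(1.5)
    val p2 = PrefixComparators.DOUBLE.computePrefix(2.5)
    assert(PrefixComparators.DOUBLE.compare(p1, p2) < 0)
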
--- .../unsafe/sort/PrefixComparators.java | 16 ++++++ .../unsafe/sort/PrefixComparatorsSuite.scala | 12 +++++ .../execution/UnsafeExternalRowSorter.java | 2 +- .../spark/sql/execution/SortPrefixUtils.scala | 51 ++++++++++--------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 5 +- .../execution/RowFormatConvertersSuite.scala | 2 +- .../execution/UnsafeExternalSortSuite.scala | 10 ++-- 8 files changed, 67 insertions(+), 35 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index a9ee6042fec74..600aff7d15d8a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -29,6 +29,7 @@ private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); public static final class StringPrefixComparator extends PrefixComparator { @@ -54,6 +55,21 @@ public int compare(long a, long b) { public final long NULL_PREFIX = Long.MIN_VALUE; } + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return Utils.nanSafeCompareFloats(a, b); + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + + public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); + } + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 26b7a9e816d1e..cf53a8ad21c60 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -55,6 +55,18 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } + test("float prefix comparator handles NaN properly") { + val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) + val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) + val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) + assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) + } + test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 8342833246f7d..4c3f2c6557140 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -121,7 +121,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 050d27f1460fb..2dee3542d6101 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BoundReference, SortOrder} +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -39,54 +39,57 @@ object SortPrefixUtils { sortOrder.dataType match { case StringType => PrefixComparators.STRING case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType | DoubleType => PrefixComparators.DOUBLE + case FloatType => PrefixComparators.FLOAT + case DoubleType => PrefixComparators.DOUBLE case _ => NoOpPrefixComparator } } def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - val bound = sortOrder.child.asInstanceOf[BoundReference] - val pos = bound.ordinal sortOrder.dataType match { - case StringType => - (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(row.getUTF8String(pos)) - } + case StringType => (row: InternalRow) => { + PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) + } case BooleanType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (row.getBoolean(pos)) 1 + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 else 0 } case ByteType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getByte(pos) + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Byte] } case ShortType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getShort(pos) + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Short] } case IntegerType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getInt(pos) + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Int] } case LongType => (row: InternalRow) => { - if (row.isNullAt(pos)) PrefixComparators.INTEGRAL.NULL_PREFIX else row.getLong(pos) + val exprVal 
= sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX + else sortOrder.child.eval(row).asInstanceOf[Long] } case FloatType => (row: InternalRow) => { - if (row.isNullAt(pos)) { - PrefixComparators.DOUBLE.NULL_PREFIX - } else { - PrefixComparators.DOUBLE.computePrefix(row.getFloat(pos).toDouble) - } + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX + else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) } case DoubleType => (row: InternalRow) => { - if (row.isNullAt(pos)) { - PrefixComparators.DOUBLE.NULL_PREFIX - } else { - PrefixComparators.DOUBLE.computePrefix(row.getDouble(pos)) - } + val exprVal = sortOrder.child.eval(row) + if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX + else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) } case _ => (row: InternalRow) => 0L } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4ab2c41f1b339..f3ef066528ff8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -340,8 +340,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - TungstenSort.supportsSchema(child.schema)) { - execution.TungstenSort(sortExprs, global, child) + UnsafeExternalSort.supportsSchema(child.schema)) { + execution.UnsafeExternalSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index d0ad310062853..f82208868c3e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -97,7 +97,7 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class TungstenSort( +case class UnsafeExternalSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, @@ -110,6 +110,7 @@ case class TungstenSort( if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { + assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) @@ -148,7 +149,7 @@ case class TungstenSort( } @DeveloperApi -object TungstenSort { +object UnsafeExternalSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index c458f95ca1ab3..7b75f755918c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -31,7 +31,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 9cabc4b90bf8e..7a4baa9e4a49d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -42,7 +42,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -53,7 +53,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { try { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -68,7 +68,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -88,11 +88,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(TungstenSort.supportsSchema(inputDf.schema)) + assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 5ba2d44068b89fd8e81cfd24f49bf20d373f81b9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 01:21:39 -0700 Subject: [PATCH 155/219] Fix flaky HashedRelationSuite SparkEnv might 
not have been set in local unit tests. Author: Reynold Xin Closes #7784 from rxin/HashedRelationSuite and squashes the following commits: 435d64b [Reynold Xin] Fix flaky HashedRelationSuite --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 7 +++++-- .../spark/sql/execution/joins/HashedRelationSuite.scala | 6 +++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 7a507391316a9..26dbc911e9521 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -260,7 +260,10 @@ private[joins] final class UnsafeHashedRelation( val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val pageSizeBytes = SparkEnv.get.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") + + val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + .getSizeAsBytes("spark.buffer.pageSize", "64m") + binaryMap = new BytesToBytesMap( memoryManager, nKeys * 2, // reduce hash collision diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 941f6d4f6a450..8b1a9b21a96b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -33,7 +33,7 @@ class HashedRelationSuite extends SparkFunSuite { override def apply(row: InternalRow): InternalRow = row } - ignore("GeneralHashedRelation") { + test("GeneralHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[GeneralHashedRelation]) @@ -47,7 +47,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(hashed.get(data(2)) === data2) } - ignore("UniqueKeyHashedRelation") { + test("UniqueKeyHashedRelation") { val data = Array(InternalRow(0), InternalRow(1), InternalRow(2)) val hashed = HashedRelation(data.iterator, keyProjection) assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) @@ -64,7 +64,7 @@ class HashedRelationSuite extends SparkFunSuite { assert(uniqHashed.getValue(InternalRow(10)) === null) } - ignore("UnsafeHashedRelation") { + test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) From 6175d6cfe795fbd88e3ee713fac375038a3993a8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 17:45:30 +0800 Subject: [PATCH 156/219] [SPARK-8838] [SQL] Add config to enable/disable merging part-files when merging parquet 
schema JIRA: https://issues.apache.org/jira/browse/SPARK-8838 Currently all part-files are merged when merging the parquet schema. However, when there are many part-files and we can make sure that all of them have the same schema as their summary file, we can provide a configuration to disable merging part-files during schema merging. In short, we need to merge parquet schemas because different summary files may contain different schemas, but the part-files are confirmed to have the same schema as the summary files. Author: Liang-Chi Hsieh Closes #7238 from viirya/option_partfile_merge and squashes the following commits: 71d5b5f [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge 8816f44 [Liang-Chi Hsieh] For comments. dbc8e6b [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge afc2fa1 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge d4ed7e6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge df43027 [Liang-Chi Hsieh] Get dataStatuses' partitions based on all paths. 4eb2f00 [Liang-Chi Hsieh] Use given parameter. ea8f6e5 [Liang-Chi Hsieh] Correct the code comments. a57be0e [Liang-Chi Hsieh] Merge part-files if there are no summary files. 47df981 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge 4caf293 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into option_partfile_merge 0e734e0 [Liang-Chi Hsieh] Use correct API. 3b6be5b [Liang-Chi Hsieh] Fix key not found. 4bdd7e0 [Liang-Chi Hsieh] Don't read footer files if we can skip them. 8bbebcb [Liang-Chi Hsieh] Figure out how to test the config. bbd4ce7 [Liang-Chi Hsieh] Add config to enable/disable merging part-files when merging parquet schema. --- .../scala/org/apache/spark/sql/SQLConf.scala | 7 +++++ .../spark/sql/parquet/ParquetRelation.scala | 19 ++++++++++++- .../spark/sql/parquet/ParquetQuerySuite.scala | 27 +++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index cdb0c7a1c07a7..2564bbd2077bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -247,6 +247,13 @@ private[spark] object SQLConf { "otherwise the schema is picked from the summary file or a random data file " + "if no summary file is available.") + val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles", + defaultValue = Some(false), + doc = "When true, we assume that all part-files of Parquet are consistent with the " + + "summary files, and we will ignore the part-files when merging the schema. Otherwise, if this is " + + "false, which is the default, we will merge all part-files. 
This should be considered " + "as an expert-only option, and shouldn't be enabled before knowing what it means exactly.") + val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", defaultValue = Some(false), doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 1a8176d8a80ab..b4337a48dbd80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -124,6 +124,9 @@ private[sql] class ParquetRelation( .map(_.toBoolean) .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + private val mergeRespectSummaries = + sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + private val maybeMetastoreSchema = parameters .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) @@ -421,7 +424,21 @@ private[sql] class ParquetRelation( val filesToTouch = if (shouldMergeSchemas) { // Also includes summary files, 'cause there might be empty partition directories. + + // If the mergeRespectSummaries config is true, we assume that all part-files have the same + // schema as the summary files, so we ignore the part-files when merging the schema. + // If the config is disabled, which is the default setting, we merge all part-files. + // In this mode, we only need to merge schemas contained in all those summary files. + // You should enable this configuration only if you are very sure that the parquet + // part-files to read have corresponding summary files containing the correct schema. + + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq() + } else { + dataStatuses + } + (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq } else { // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet // don't have this. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index c037faf4cfd92..a95f70f2bba69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.parquet +import java.io.File + import org.apache.hadoop.fs.Path import org.apache.spark.sql.types._ import org.apache.spark.sql.{QueryTest, Row, SQLConf} +import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. */ @@ -123,6 +126,30 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { } } + test("Enabling/disabling merging partfiles when merging parquet schema") { + def testSchemaMerging(expectedColumnNumber: Int): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString) + sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString) + // delete summary files, so if we don't merge part-files, one column will not be included. 
+ Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata")) + Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata")) + assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber) + } + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "true") { + testSchemaMerging(2) + } + + withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "true", + SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES.key -> "false") { + testSchemaMerging(3) + } + } + test("Enabling/disabling schema merging") { def testSchemaMerging(expectedColumnNumber: Int): Unit = { withTempDir { dir => From d31c618e3c8838f8198556876b9dcbbbf835f7b2 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 30 Jul 2015 07:49:10 -0700 Subject: [PATCH 157/219] [SPARK-7368] [MLLIB] Add QR decomposition for RowMatrix jira: https://issues.apache.org/jira/browse/SPARK-7368 Add QR decomposition for RowMatrix. I'm not sure what the community's blueprint for the distributed Matrix is, or whether this will be a desirable feature, so I sent a prototype for discussion. I'll go on to polish the code and provide unit tests and performance statistics if it's acceptable. The implementation refers to the [paper: https://www.cs.purdue.edu/homes/dgleich/publications/Benson%202013%20-%20direct-tsqr.pdf] Austin R. Benson, David F. Gleich, James Demmel. "Direct QR factorizations for tall-and-skinny matrices in MapReduce architectures", 2013 IEEE International Conference on Big Data, which describes a stable algorithm with good scalability. So far I have tried it on a 400000 * 500 RowMatrix (16 partitions), and it brings the computation time down from 8.8 mins (using breeze.linalg.qr.reduced) to 2.6 mins on a 4-worker cluster. I think there is still some room for performance improvement. Any trials and suggestions are welcome. Author: Yuhao Yang Closes #5909 from hhbyyh/qrDecomposition and squashes the following commits: cec797b [Yuhao Yang] remove unnecessary qr 0fb1012 [Yuhao Yang] hierarchy R computing 3fbdb61 [Yuhao Yang] update qr to indirect and add ut 0d913d3 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into qrDecomposition 39213c3 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into qrDecomposition c0fc0c7 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into qrDecomposition 39b0b22 [Yuhao Yang] initial draft for discussion --- .../linalg/SingularValueDecomposition.scala | 8 ++++ .../mllib/linalg/distributed/RowMatrix.scala | 46 ++++++++++++++++++- .../linalg/distributed/RowMatrixSuite.scala | 17 +++++++ 3 files changed, 70 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index 9669c364bad8f..b416d50a5631e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -25,3 +25,11 @@ import org.apache.spark.annotation.Experimental */ @Experimental case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) + +/** + * :: Experimental :: + * Represents QR factors. 
+ */ +@Experimental +case class QRDecomposition[UType, VType](Q: UType, R: VType) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 1626da9c3d2ee..bfc90c9ef8527 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -22,7 +22,7 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd} + svd => brzSvd, MatrixSingularException, inv} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -497,6 +497,50 @@ class RowMatrix( columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) } + /** + * Computes the QR decomposition for a [[RowMatrix]]. The implementation is designed to optimize + * the QR decomposition (factorization) for a [[RowMatrix]] of tall-and-skinny shape. + * Reference: + * Paul G. Constantine, David F. Gleich. "Tall and skinny QR factorizations in MapReduce + * architectures" ([[http://dx.doi.org/10.1145/1996092.1996103]]) + * + * @param computeQ whether to compute Q + * @return QRDecomposition(Q, R), where Q = null if computeQ = false. + */ + def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = { + val col = numCols().toInt + // split rows horizontally into smaller matrices, and compute QR for each of them + val blockQRs = rows.glom().map { partRows => + val bdm = BDM.zeros[Double](partRows.length, col) + var i = 0 + partRows.foreach { row => + bdm(i, ::) := row.toBreeze.t + i += 1 + } + breeze.linalg.qr.reduced(bdm).r + } + + // combine the R part from previous results vertically into a tall matrix + val combinedR = blockQRs.treeReduce{ (r1, r2) => + val stackedR = BDM.vertcat(r1, r2) + breeze.linalg.qr.reduced(stackedR).r + } + val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix) + val finalQ = if (computeQ) { + try { + val invR = inv(combinedR) + this.multiply(Matrices.fromBreeze(invR)) + } catch { + case err: MatrixSingularException => + logWarning("R is not invertible; returning Q as null") + null + } + } else { + null + } + QRDecomposition(finalQ, finalR) + } + /** * Find all similar columns using the DIMSUM sampling algorithm, described in two papers * diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index b6cb53d0c743e..283ffec1d49d7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg.distributed import scala.util.Random +import breeze.numerics.abs import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite @@ -238,6 +239,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("QR Decomposition") { + for (mat <- Seq(denseMat, sparseMat)) { + val result = mat.tallSkinnyQR(true) + val expected = breeze.linalg.qr.reduced(mat.toBreeze()) + val calcQ = result.Q + val calcR = result.R + assert(closeToZero(abs(expected.q) - abs(calcQ.toBreeze()))) + 
assert(closeToZero(abs(expected.r) - abs(calcR.toBreeze.asInstanceOf[BDM[Double]]))) + assert(closeToZero(calcQ.multiply(calcR).toBreeze - mat.toBreeze())) + // Decomposition without computing Q + val rOnly = mat.tallSkinnyQR(computeQ = false) + assert(rOnly.Q == null) + assert(closeToZero(abs(expected.r) - abs(rOnly.R.toBreeze.asInstanceOf[BDM[Double]]))) + } + } } class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext { From c5815930be46a89469440b7c61b59764fb67a54c Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jul 2015 07:56:15 -0700 Subject: [PATCH 158/219] [SPARK-5561] [MLLIB] Generalized PeriodicCheckpointer for RDDs and Graphs PeriodicGraphCheckpointer was introduced for Latent Dirichlet Allocation (LDA), but it was meant to be generalized to work with Graphs, RDDs, and other data structures based on RDDs. This PR generalizes it. For those who are not familiar with the periodic checkpointer, it tries to automatically handle persisting/unpersisting and checkpointing/removing checkpoint files in a lineage of RDD-based objects. I need it generalized to use with GradientBoostedTrees [https://issues.apache.org/jira/browse/SPARK-6684]. It should be useful for other iterative algorithms as well. Changes I made: * Copied PeriodicGraphCheckpointer to PeriodicCheckpointer. * Within PeriodicCheckpointer, I created abstract methods for the basic operations (checkpoint, persist, etc.). * The subclasses for Graphs and RDDs implement those abstract methods. * I copied the test suite for the graph checkpointer and made tiny modifications to make it work for RDDs. To review this PR, I recommend doing 2 diffs: (1) diff between the old PeriodicGraphCheckpointer.scala and the new PeriodicCheckpointer.scala (2) diff between the 2 test suites CCing andrewor14 in case there are relevant changes to checkpointing. CCing feynmanliang in case you're interested in learning about checkpointing. CCing mengxr for final OK. Thanks all! Author: Joseph K. Bradley Closes #7728 from jkbradley/gbt-checkpoint and squashes the following commits: d41902c [Joseph K. Bradley] Oops, forgot to update an extra time in the checkpointer tests, after the last commit. I'll fix that. I'll also make some of the checkpointer methods protected, which I should have done before. 32b23b8 [Joseph K. Bradley] fixed usage of checkpointer in lda 0b3dbc0 [Joseph K. Bradley] Changed checkpointer constructor not to take initial data. 568918c [Joseph K. Bradley] Generalized PeriodicGraphCheckpointer to PeriodicCheckpointer, with subclasses for RDDs and Graphs. 
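
For readers who want to try the tallSkinnyQR API added in the [SPARK-7368] patch earlier in this series, a minimal usage sketch. This is illustrative only: the data is made up, and an existing SparkContext `sc` is assumed.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.linalg.distributed.RowMatrix

    // A tall-and-skinny matrix: many rows, few columns.
    val rows = sc.parallelize(Seq(
      Vectors.dense(3.0, -1.0),
      Vectors.dense(4.0, 2.0),
      Vectors.dense(0.0, 5.0)))
    val mat = new RowMatrix(rows)

    // computeQ defaults to false; Q is also returned as null if R is singular.
    val qr = mat.tallSkinnyQR(computeQ = true)
    val q: RowMatrix = qr.Q  // distributed factor with orthonormal columns
    val r = qr.R             // small local upper-triangular factor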
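The update() contract described in the [SPARK-5561] commit message above (update first, then materialize) is easy to get wrong, so here is a minimal sketch of the intended usage in an iterative algorithm, based on the scaladoc and tests in the diff that follows. The per-iteration transformation is hypothetical, and since the class is private[mllib], the pattern applies to MLlib-internal code such as the GradientBoostedTrees use case mentioned above.

    import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer

    sc.setCheckpointDir("/tmp/checkpoints")  // without a checkpoint dir, only persisting happens
    var data = sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
    val checkpointer = new PeriodicRDDCheckpointer[Double](10, sc)  // checkpoint every 10th update
    checkpointer.update(data)
    data.count()  // materialize AFTER update(), as the contract requires
    var i = 0
    while (i < 100) {
      data = data.map(_ + 1.0)   // hypothetical new RDD in the lineage
      checkpointer.update(data)  // register it before it is materialized
      data.count()               // now materialize; persist/checkpoint take effect
      i += 1
    }
    checkpointer.deleteAllCheckpoints()  // finally, remove any remaining checkpoint files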
--- .../spark/mllib/clustering/LDAOptimizer.scala | 6 +- .../mllib/impl/PeriodicCheckpointer.scala | 154 ++++++++++++++++ .../impl/PeriodicGraphCheckpointer.scala | 105 ++--------- .../mllib/impl/PeriodicRDDCheckpointer.scala | 97 ++++++++++ .../impl/PeriodicGraphCheckpointerSuite.scala | 16 +- .../impl/PeriodicRDDCheckpointerSuite.scala | 173 ++++++++++++++++++ 6 files changed, 452 insertions(+), 99 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 7e75e7083acb5..4b90fbdf0ce7e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -142,8 +142,8 @@ final class EMLDAOptimizer extends LDAOptimizer { this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval - this.graphCheckpointer = new - PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval) + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) this.globalTopicTotals = computeGlobalTopicTotals() this } @@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer { // Update the vertex descriptors with the new counts. val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) graph = newGraph - graphCheckpointer.updateGraph(newGraph) + graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala new file mode 100644 index 0000000000000..72d3aabc9b1f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.storage.StorageLevel + + +/** + * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs + * (such as Graphs and DataFrames). In documentation, we use the phrase "Dataset" to refer to + * the distributed data type (RDD, Graph, etc.). 
+ * + * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing, + * as well as unpersisting and removing checkpoint files. + * + * Users should call update() when a new Dataset has been created, + * before the Dataset has been materialized. After updating [[PeriodicCheckpointer]], users are + * responsible for materializing the Dataset to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets. + * - Unpersist Datasets from queue until there are at most 3 persisted Datasets. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Datasets should be + * checkpointed). + * - This class removes checkpoint files once later Datasets have been checkpointed. + * However, references to the older Datasets will still return isCheckpointed = true. + * + * @param checkpointInterval Datasets will be checkpointed at this interval + * @param sc SparkContext for the Datasets given to this checkpointer + * @tparam T Dataset type, such as RDD[Double] + */ +private[mllib] abstract class PeriodicCheckpointer[T]( + val checkpointInterval: Int, + val sc: SparkContext) extends Logging { + + /** FIFO queue of past checkpointed Datasets */ + private val checkpointQueue = mutable.Queue[T]() + + /** FIFO queue of past persisted Datasets */ + private val persistedQueue = mutable.Queue[T]() + + /** Number of times [[update()]] has been called */ + private var updateCount = 0 + + /** + * Update with a new Dataset. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the Dataset + * has been materialized. + * + * @param newData New Dataset created from previous Datasets in the lineage. + */ + def update(newData: T): Unit = { + persist(newData) + persistedQueue.enqueue(newData) + // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class: + // Users should call [[update()]] when a new Dataset has been created, + // before the Dataset has been materialized. + while (persistedQueue.size > 3) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + checkpoint(newData) + checkpointQueue.enqueue(newData) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (isCheckpointed(checkpointQueue.head)) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** Checkpoint the Dataset */ + protected def checkpoint(data: T): Unit + + /** Return true iff the Dataset is checkpointed */ + protected def isCheckpointed(data: T): Boolean + + /** + * Persist the Dataset. + * Note: This should handle checking the current [[StorageLevel]] of the Dataset. 
+ */ + protected def persist(data: T): Unit + + /** Unpersist the Dataset */ + protected def unpersist(data: T): Unit + + /** Get list of checkpoint files for this given Dataset */ + protected def getCheckpointFiles(data: T): Iterable[String] + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.nonEmpty) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + getCheckpointFiles(old).foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 6e5dd119dd653..11a059536c50c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -17,11 +17,7 @@ package org.apache.spark.mllib.impl -import scala.collection.mutable - -import org.apache.hadoop.fs.{Path, FileSystem} - -import org.apache.spark.Logging +import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel @@ -31,12 +27,12 @@ import org.apache.spark.storage.StorageLevel * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as * unpersisting and removing checkpoint files. * - * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * Users should call update() when a new graph has been created, * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are * responsible for materializing the graph to ensure that persisting and checkpointing actually * occur. * - * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * When update() is called, this does the following: * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. * - Unpersist graphs from queue until there are at most 3 persisted graphs. * - If using checkpointing and the checkpoint interval has been reached, @@ -52,7 +48,7 @@ import org.apache.spark.storage.StorageLevel * Example usage: * {{{ * val (graph1, graph2, graph3, ...) = ... - * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * val cp = new PeriodicGraphCheckpointer(2, sc) * graph1.vertices.count(); graph1.edges.count() * // persisted: graph1 * cp.updateGraph(graph2) @@ -73,99 +69,30 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param currentGraph Initial graph * @param checkpointInterval Graphs will be checkpointed at this interval * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + * TODO: Move this out of MLlib? 
*/ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( - var currentGraph: Graph[VD, ED], - val checkpointInterval: Int) extends Logging { - - /** FIFO queue of past checkpointed RDDs */ - private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() - - /** FIFO queue of past persisted RDDs */ - private val persistedQueue = mutable.Queue[Graph[VD, ED]]() - - /** Number of times [[updateGraph()]] has been called */ - private var updateCount = 0 - - /** - * Spark Context for the Graphs given to this checkpointer. - * NOTE: This code assumes that only one SparkContext is used for the given graphs. - */ - private val sc = currentGraph.vertices.sparkContext + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { - updateGraph(currentGraph) + override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint() - /** - * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. - * Since this handles persistence and checkpointing, this should be called before the graph - * has been materialized. - * - * @param newGraph New graph created from previous graphs in the lineage. - */ - def updateGraph(newGraph: Graph[VD, ED]): Unit = { - if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { - newGraph.persist() - } - persistedQueue.enqueue(newGraph) - // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: - // Users should call [[updateGraph()]] when a new graph has been created, - // before the graph has been materialized. - while (persistedQueue.size > 3) { - val graphToUnpersist = persistedQueue.dequeue() - graphToUnpersist.unpersist(blocking = false) - } - updateCount += 1 + override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed - // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { - // Add new checkpoint before removing old checkpoints. - newGraph.checkpoint() - checkpointQueue.enqueue(newGraph) - // Remove checkpoints before the latest one. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // Delete the oldest checkpoint only if the next checkpoint exists. - if (checkpointQueue.get(1).get.isCheckpointed) { - removeCheckpointFile() - } else { - canDelete = false - } - } + override protected def persist(data: Graph[VD, ED]): Unit = { + if (data.vertices.getStorageLevel == StorageLevel.NONE) { + data.persist() } } - /** - * Call this at the end to delete any remaining checkpoint files. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.size > 0) { - removeCheckpointFile() - } - } + override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false) - /** - * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. - * This prints a warning but does not fail if the files cannot be removed. - */ - private def removeCheckpointFile(): Unit = { - val old = checkpointQueue.dequeue() - // Since the old checkpoint is not deleted by Spark, we manually delete it. 
- val fs = FileSystem.get(sc.hadoopConfiguration) - old.getCheckpointFiles.foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = { + data.getCheckpointFiles } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala new file mode 100644 index 0000000000000..f31ed2aa90a64 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.apache.spark.SparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing RDDs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call update() when a new RDD has been created, + * before the RDD has been materialized. After updating [[PeriodicRDDCheckpointer]], users are + * responsible for materializing the RDD to ensure that persisting and checkpointing actually + * occur. + * + * When update() is called, this does the following: + * - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs. + * - Unpersist RDDs from queue until there are at most 3 persisted RDDs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new RDD, and put in a queue of checkpointed RDDs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which RDDs should be + * checkpointed). + * - This class removes checkpoint files once later RDDs have been checkpointed. + * However, references to the older RDDs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (rdd1, rdd2, rdd3, ...) = ... 
+ * val cp = new PeriodicRDDCheckpointer(2, sc) + * rdd1.count(); + * // persisted: rdd1 + * cp.update(rdd2) + * rdd2.count(); + * // persisted: rdd1, rdd2 + * // checkpointed: rdd2 + * cp.update(rdd3) + * rdd3.count(); + * // persisted: rdd1, rdd2, rdd3 + * // checkpointed: rdd2 + * cp.update(rdd4) + * rdd4.count(); + * // persisted: rdd2, rdd3, rdd4 + * // checkpointed: rdd4 + * cp.update(rdd5) + * rdd5.count(); + * // persisted: rdd3, rdd4, rdd5 + * // checkpointed: rdd4 + * }}} + * + * @param checkpointInterval RDDs will be checkpointed at this interval + * @tparam T RDD element type + * + * TODO: Move this out of MLlib? + */ +private[mllib] class PeriodicRDDCheckpointer[T]( + checkpointInterval: Int, + sc: SparkContext) + extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) { + + override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint() + + override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed + + override protected def persist(data: RDD[T]): Unit = { + if (data.getStorageLevel == StorageLevel.NONE) { + data.persist() + } + } + + override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false) + + override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = { + data.getCheckpointFile.map(x => x) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala index d34888af2d73b..e331c75989187 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo import PeriodicGraphCheckpointerSuite._ - // TODO: Do I need to call count() on the graphs' RDDs? 
- test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, 10) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) checkPersistence(graphsToCheck, 1) var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) checkPersistence(graphsToCheck, iteration) iteration += 1 @@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var graphsToCheck = Seq.empty[GraphToCheck] sc.setCheckpointDir(path) val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) graph1.edges.count() graph1.vertices.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) @@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo var iteration = 2 while (iteration < 9) { val graph = createGraph(sc) - checkpointer.updateGraph(graph) + checkpointer.update(graph) graph.vertices.count() graph.edges.count() graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) @@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite { } else { // Graph should never be checkpointed assert(!graph.isCheckpointed, "Graph should never have been checkpointed") - assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files") } } catch { case e: AssertionError => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala new file mode 100644 index 0000000000000..b2a459a68b5fa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.impl + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.{SparkContext, SparkFunSuite} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PeriodicRDDCheckpointerSuite._ + + test("Persisting") { + var rddsToCheck = Seq.empty[RDDToCheck] + + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext) + checkpointer.update(rdd1) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkPersistence(rddsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkPersistence(rddsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var rddsToCheck = Seq.empty[RDDToCheck] + sc.setCheckpointDir(path) + val rdd1 = createRDD(sc) + val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext) + checkpointer.update(rdd1) + rdd1.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1) + checkCheckpoint(rddsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val rdd = createRDD(sc) + checkpointer.update(rdd) + rdd.count() + rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration) + checkCheckpoint(rddsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + rddsToCheck.foreach { rdd => + confirmCheckpointRemoved(rdd.rdd) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicRDDCheckpointerSuite { + + case class RDDToCheck(rdd: RDD[Double], gIndex: Int) + + def createRDD(sc: SparkContext): RDD[Double] = { + sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0)) + } + + def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = { + rdds.foreach { g => + checkPersistence(g.rdd, g.gIndex, iteration) + } + } + + /** + * Check storage level of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(rdd.getStorageLevel == StorageLevel.NONE) + } else { + assert(rdd.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n") + } + } + + def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = { + rdds.reverse.foreach { g => + checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(rdd: RDD[_]): Unit = { + // Note: We cannot check rdd.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this rdd.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). 
+ val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration) + rdd.getCheckpointFile.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of rdd. + * @param gIndex Index of rdd in order inserted into checkpointer (from 1). + * @param iteration Total number of rdds inserted into checkpointer. + */ + def checkCheckpoint( + rdd: RDD[_], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd) + // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(rdd.isCheckpointed, "RDD should be checkpointed") + assert(rdd.getCheckpointFile.nonEmpty, "RDD should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(rdd) + } + } else { + // RDD should never be checkpointed + assert(!rdd.isCheckpointed, "RDD should never have been checkpointed") + assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" + + s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +} From d212a314227dec26c0dbec8ed3422d0ec8f818f9 Mon Sep 17 00:00:00 2001 From: zhangjiajin Date: Thu, 30 Jul 2015 08:14:09 -0700 Subject: [PATCH 159/219] [SPARK-8998] [MLLIB] Distribute PrefixSpan computation for large projected databases Continuation of work by zhangjiajin Closes #7412 Author: zhangjiajin Author: Feynman Liang Author: zhang jiajin Closes #7783 from feynmanliang/SPARK-8998-improve-distributed and squashes the following commits: a61943d [Feynman Liang] Collect small patterns to local 4ddf479 [Feynman Liang] Parallelize freqItemCounts ad23aa9 [zhang jiajin] Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal 87fa021 [Feynman Liang] Improve extend prefix readability c2caa5c [Feynman Liang] Readability improvements and comments 1235cfc [Feynman Liang] Use Iterable[Array[_]] over Array[Array[_]] for database da0091b [Feynman Liang] Use lists for prefixes to reuse data cb2a4fc [Feynman Liang] Inline code for readability 01c9ae9 [Feynman Liang] Add getters 6e149fa [Feynman Liang] Fix splitPrefixSuffixPairs 64271b3 [zhangjiajin] Modified codes according to comments. d2250b7 [zhangjiajin] remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalProcessing. b07e20c [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark into CollectEnoughPrefixes 095aa3a [zhangjiajin] Modified the code according to the review comments. baa2885 [zhangjiajin] Modified the code according to the review comments. 6560c69 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixeSpan a8fde87 [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark 4dd1c8a [zhangjiajin] initialize file before rebase. 078d410 [zhangjiajin] fix a scala style error. 22b0ef4 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan. 
ca9c4c8 [zhangjiajin] Modified the code according to the review comments. 574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization. ba5df34 [zhangjiajin] Fix a Scala style error. 4c60fb3 [zhangjiajin] Fix some Scala style errors. 1dd33ad [zhangjiajin] Modified the code according to the review comments. 89bc368 [zhangjiajin] Fixed a Scala style error. a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala 951fd42 [zhang jiajin] Delete Prefixspan.scala 575995f [zhangjiajin] Modified the code according to the review comments. 91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file. --- .../spark/mllib/fpm/LocalPrefixSpan.scala | 6 +- .../apache/spark/mllib/fpm/PrefixSpan.scala | 203 +++++++++++++----- .../spark/mllib/fpm/PrefixSpanSuite.scala | 21 +- 3 files changed, 161 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index 7ead6327486cc..0ea792081086d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { minCount: Long, maxPatternLength: Int, prefixes: List[Int], - database: Array[Array[Int]]): Iterator[(List[Int], Long)] = { + database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = { if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty val frequentItemAndCounts = getFreqItemAndCounts(minCount, database) val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains)) @@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { } } - def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = { + def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = { database .map(getSuffix(prefix, _)) .filter(_.nonEmpty) @@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable { */ private def getFreqItemAndCounts( minCount: Long, - database: Array[Array[Int]]): mutable.Map[Int, Long] = { + database: Iterable[Array[Int]]): mutable.Map[Int, Long] = { // TODO: use PrimitiveKeyOpenHashMap val counts = mutable.Map[Int, Long]().withDefaultValue(0L) database.foreach { sequence => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 6f52db7b073ae..e6752332cdeeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.fpm +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD @@ -43,28 +45,45 @@ class PrefixSpan private ( private var minSupport: Double, private var maxPatternLength: Int) extends Logging with Serializable { + /** + * The maximum number of items allowed in a projected database before local processing. If a + * projected database exceeds this size, another iteration of distributed PrefixSpan is run. + */ + // TODO: make configurable with a better default value, 10000 may be too small + private val maxLocalProjDBSize: Long = 10000 + /** * Constructs a default instance with default parameters * {minSupport: `0.1`, maxPatternLength: `10`}. 
   */
   def this() = this(0.1, 10)
 
+  /**
+   * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
+   * frequent).
+   */
+  def getMinSupport: Double = this.minSupport
+
   /**
    * Sets the minimal support level (default: `0.1`).
    */
   def setMinSupport(minSupport: Double): this.type = {
-    require(minSupport >= 0 && minSupport <= 1,
-      "The minimum support value must be between 0 and 1, including 0 and 1.")
+    require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
     this.minSupport = minSupport
     this
   }
 
+  /**
+   * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to
+   * consider).
+   */
+  def getMaxPatternLength: Int = this.maxPatternLength
+
   /**
    * Sets maximal pattern length (default: `10`).
    */
   def setMaxPatternLength(maxPatternLength: Int): this.type = {
-    require(maxPatternLength >= 1,
-      "The maximum pattern length value must be greater than 0.")
+    // TODO: support unbounded pattern length when maxPatternLength = 0
+    require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
     this.maxPatternLength = maxPatternLength
     this
   }
@@ -78,81 +97,153 @@ class PrefixSpan private (
    *         the value of pair is the pattern's count.
    */
   def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+    val sc = sequences.sparkContext
+
     if (sequences.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
-    val minCount = getMinCount(sequences)
-    val lengthOnePatternsAndCounts =
-      getFreqItemAndCounts(minCount, sequences).collect()
-    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
-      lengthOnePatternsAndCounts.map(_._1), sequences)
-    val groupedProjectedDatabase = prefixAndProjectedDatabase
-      .map(x => (x._1.toSeq, x._2))
-      .groupByKey()
-      .map(x => (x._1.toArray, x._2.toArray))
-    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
-    val lengthOnePatternsAndCountsRdd =
-      sequences.sparkContext.parallelize(
-        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
-    val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
-    allPatterns
+
+    // Convert min support to a min number of transactions for this dataset
+    val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+    // (Frequent items -> number of occurrences); all items here satisfy the `minSupport` threshold
+    val freqItemCounts = sequences
+      .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+      .reduceByKey(_ + _)
+      .filter(_._2 >= minCount)
+      .collect()
+
+    // Pairs of (length 1 prefix, suffix consisting of frequent items)
+    val itemSuffixPairs = {
+      val freqItems = freqItemCounts.map(_._1).toSet
+      sequences.flatMap { seq =>
+        val filteredSeq = seq.filter(freqItems.contains(_))
+        freqItems.flatMap { item =>
+          val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+          candidateSuffix match {
+            case suffix if !suffix.isEmpty => Some((List(item), suffix))
+            case _ => None
+          }
+        }
+      }
+    }
+
+    // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
+    // frequent length-one prefixes)
+    var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+    // Remaining work to be locally and distributively processed respectively
+    var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+    // Continue processing until no pairs for distributed processing remain (i.e.
all prefixes have + // projected database sizes <= `maxLocalProjDBSize`) + while (pairsForDistributed.count() != 0) { + val (nextPatternAndCounts, nextPrefixSuffixPairs) = + extendPrefixes(minCount, pairsForDistributed) + pairsForDistributed.unpersist() + val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) + pairsForDistributed = largerPairsPart + pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) + pairsForLocal ++= smallerPairsPart + resultsAccumulator ++= nextPatternAndCounts.collect() + } + + // Process the small projected databases locally + val remainingResults = getPatternsInLocal( + minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) + + (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) + .map { case (pattern, count) => (pattern.toArray, count) } } + /** - * Get the minimum count (sequences count * minSupport). - * @param sequences input data set, contains a set of sequences, - * @return minimum count, + * Partitions the prefix-suffix pairs by projected database size. + * @param prefixSuffixPairs prefix (length n) and suffix pairs, + * @return prefix-suffix pairs partitioned by whether their projected database size is <= or + * greater than [[maxLocalProjDBSize]] */ - private def getMinCount(sequences: RDD[Array[Int]]): Long = { - if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = { + val prefixToSuffixSize = prefixSuffixPairs + .aggregateByKey(0)( + seqOp = { case (count, suffix) => count + suffix.length }, + combOp = { _ + _ }) + val smallPrefixes = prefixToSuffixSize + .filter(_._2 <= maxLocalProjDBSize) + .keys + .collect() + .toSet + val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } + val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } + (small.collect(), large) } /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the absolute minimum count - * @param sequences original sequences data - * @return array of item and count pair + * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes + * and remaining work. + * @param minCount minimum count + * @param prefixSuffixPairs prefix (length N) and suffix pairs, + * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended + * prefix, corresponding suffix) pairs. */ - private def getFreqItemAndCounts( + private def extendPrefixes( minCount: Long, - sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { - sequences.flatMap(_.distinct.map((_, 1L))) + prefixSuffixPairs: RDD[(List[Int], Array[Int])]) + : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = { + + // (length N prefix, item from suffix) pairs and their corresponding number of occurrences + // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` + val prefixItemPairAndCounts = prefixSuffixPairs + .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) } .reduceByKey(_ + _) .filter(_._2 >= minCount) - } - /** - * Get the frequent prefixes' projected database. 
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
-   */
-  private def getPrefixAndProjectedDatabase(
-      frequentPrefixes: Array[Int],
-      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
-    val filteredSequences = sequences.map { p =>
-      p.filter (frequentPrefixes.contains(_) )
-    }
-    filteredSequences.flatMap { x =>
-      frequentPrefixes.map { y =>
-        val sub = LocalPrefixSpan.getSuffix(y, x)
-        (Array(y), sub)
-      }.filter(_._2.nonEmpty)
-    }
+    // Map from prefix to set of possible next items from suffix
+    val prefixToNextItems = prefixItemPairAndCounts
+      .keys
+      .groupByKey()
+      .mapValues(_.toSet)
+      .collect()
+      .toMap
+
+    // Frequent patterns with length N+1 and their corresponding counts
+    val extendedPrefixAndCounts = prefixItemPairAndCounts
+      .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+    // Remaining work, all prefixes will have length N+1
+    val extendedPrefixAndSuffix = prefixSuffixPairs
+      .filter(x => prefixToNextItems.contains(x._1))
+      .flatMap { case (prefix, suffix) =>
+        val frequentNextItems = prefixToNextItems(prefix)
+        val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+        frequentNextItems.flatMap { item =>
+          LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+            case suffix if !suffix.isEmpty => Some((item :: prefix, suffix))
+            case _ => None
+          }
+        }
+      }
+
+    (extendedPrefixAndCounts, extendedPrefixAndSuffix)
   }
 
   /**
-   * calculate the patterns in local.
+   * Calculate the patterns locally.
    * @param minCount the absolute minimum count
-   * @param data patterns and projected sequences data data
+   * @param data prefixes and projected sequences data
    * @return patterns
    */
   private def getPatternsInLocal(
       minCount: Long,
-      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { case (prefix, projDB) =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
-        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+      data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+    data.flatMap {
+      case (prefix, projDB) =>
+        LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+          .map { case (pattern: List[Int], count: Long) =>
+            (pattern.reverse, count)
+          }
     }
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 9f107c89f6d80..6dd2dc926acc5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val rdd = sc.parallelize(sequences, 2).cache()
 
-    def compareResult(
-      expectedValue: Array[(Array[Int], Long)],
-      actualValue: Array[(Array[Int], Long)]): Boolean = {
-      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
-        actualValue.map(x => (x._1.toSeq, x._2)).toSet
-    }
-
     val prefixspan = new PrefixSpan()
       .setMinSupport(0.33)
       .setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
       (Array(4, 5), 2L),
       (Array(5), 3L)
     )
-    assert(compareResult(expectedValue1, result1.collect()))
+    assert(compareResults(expectedValue1, result1.collect()))
 
     prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
    val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with
MLlibTestSparkContext { (Array(4), 4L), (Array(5), 3L) ) - assert(compareResult(expectedValue2, result2.collect())) + assert(compareResults(expectedValue2, result2.collect())) prefixspan.setMinSupport(0.33).setMaxPatternLength(2) val result3 = prefixspan.run(rdd) @@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(4, 5), 2L), (Array(5), 3L) ) - assert(compareResult(expectedValue3, result3.collect())) + assert(compareResults(expectedValue3, result3.collect())) + } + + private def compareResults( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Boolean = { + expectedValue.map(x => (x._1.toSeq, x._2)).toSet == + actualValue.map(x => (x._1.toSeq, x._2)).toSet } + } From 9c0501c5d04d83ca25ce433138bf64df6a14dc58 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 30 Jul 2015 08:20:52 -0700 Subject: [PATCH 160/219] [SPARK-] [MLLIB] minor fix on tokenizer doc A trivial fix for the comments of RegexTokenizer. Maybe this is too small, yet I just noticed it and think it can be quite misleading. I can create a jira if necessary. Author: Yuhao Yang Closes #7791 from hhbyyh/docFix and squashes the following commits: cdf2542 [Yuhao Yang] minor fix on tokenizer doc --- .../src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 0b3af4747e693..248288ca73e99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -50,7 +50,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S /** * :: Experimental :: * A regex based tokenizer that extracts tokens either by using the provided regex pattern to split - * the text (default) or repeatedly matching the regex (if `gaps` is true). + * the text (default) or repeatedly matching the regex (if `gaps` is false). * Optional parameters also allow filtering tokens using a minimal length. * It returns an array of strings that can be empty. */ From a6e53a9c8b24326d1b6dca7a0e36ce6c643daa77 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Thu, 30 Jul 2015 08:52:01 -0700 Subject: [PATCH 161/219] [SPARK-9225] [MLLIB] LDASuite needs unit tests for empty documents Add unit tests for running LDA with empty documents. Both EMLDAOptimizer and OnlineLDAOptimizer are tested. 
feynmanliang Author: Meihua Wu Closes #7620 from rotationsymmetry/SPARK-9225 and squashes the following commits: 3ed7c88 [Meihua Wu] Incorporate reviewer's further comments f9432e8 [Meihua Wu] Incorporate reviewer's comments 8e1b9ec [Meihua Wu] Merge remote-tracking branch 'upstream/master' into SPARK-9225 ad55665 [Meihua Wu] Add unit tests for running LDA with empty documents --- .../spark/mllib/clustering/LDASuite.scala | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index b91c7cefed22e..61d2edfd9fb5f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -390,6 +390,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("EMLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new EMLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + + test("OnlineLDAOptimizer with empty docs") { + val vocabSize = 6 + val emptyDocsArray = Array.fill(6)(Vectors.sparse(vocabSize, Array.empty, Array.empty)) + val emptyDocs = emptyDocsArray + .zipWithIndex.map { case (wordCounts, docId) => + (docId.toLong, wordCounts) + } + val distributedEmptyDocs = sc.parallelize(emptyDocs, 2) + + val op = new OnlineLDAOptimizer() + val lda = new LDA() + .setK(3) + .setMaxIterations(5) + .setSeed(12345) + .setOptimizer(op) + + val model = lda.run(distributedEmptyDocs) + assert(model.vocabSize === vocabSize) + } + } private[clustering] object LDASuite { From ed3cb1d21c73645c8f6e6ee08181f876fc192e41 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 30 Jul 2015 09:19:55 -0700 Subject: [PATCH 162/219] [SPARK-9277] [MLLIB] SparseVector constructor must throw an error when declared number of elements less than array length Check that SparseVector size is at least as big as the number of indices/values provided. And add tests for constructor checks. CC MechCoder jkbradley -- I am not sure if a change needs to also happen in the Python API? I didn't see it had any similar checks to begin with, but I don't know it well. Author: Sean Owen Closes #7794 from srowen/SPARK-9277 and squashes the following commits: e8dc31e [Sean Owen] Fix scalastyle 6ffe34a [Sean Owen] Check that SparseVector size is at least as big as the number of indices/values provided. And add tests for constructor checks. 
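
A quick sketch of the behavior this patch enforces, mirroring the new tests in the diff below (it assumes only the standard mllib Vectors factory):

    import org.apache.spark.mllib.linalg.Vectors

    // Fine: 3 index/value pairs fit in a vector of declared size 4.
    val ok = Vectors.sparse(4, Array(0, 2, 3), Array(1.0, 3.0, 9.0))

    // Previously constructed silently; now rejected at construction time.
    try {
      Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0))
    } catch {
      case e: IllegalArgumentException => println(s"Rejected: ${e.getMessage}")
    }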
--- .../org/apache/spark/mllib/linalg/Vectors.scala | 2 ++ .../apache/spark/mllib/linalg/VectorsSuite.scala | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0cb28d78bec05..23c2c16d68d9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -637,6 +637,8 @@ class SparseVector( require(indices.length == values.length, "Sparse vectors require that the dimension of the" + s" indices match the dimension of the values. You provided ${indices.length} indices and " + s" ${values.length} values.") + require(indices.length <= size, s"You provided ${indices.length} indices and values, " + + s"which exceeds the specified vector size ${size}.") override def toString: String = s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 03be4119bdaca..1c37ea5123e82 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -57,6 +57,21 @@ class VectorsSuite extends SparkFunSuite with Logging { assert(vec.values === values) } + test("sparse vector construction with mismatched indices/values array") { + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0, 7.0, 9.0)) + } + intercept[IllegalArgumentException] { + Vectors.sparse(4, Array(1, 2, 3), Array(3.0, 5.0)) + } + } + + test("sparse vector construction with too many indices vs size") { + intercept[IllegalArgumentException] { + Vectors.sparse(3, Array(1, 2, 3, 4), Array(3.0, 5.0, 7.0, 9.0)) + } + } + test("dense to array") { val vec = Vectors.dense(arr).asInstanceOf[DenseVector] assert(vec.toArray.eq(arr)) From 81464f2a8243c6ae2a39bac7ebdc50d4f60af451 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 09:45:17 -0700 Subject: [PATCH 163/219] [MINOR] [MLLIB] fix doc for RegexTokenizer This is #7791 for Python. hhbyyh Author: Xiangrui Meng Closes #7798 from mengxr/regex-tok-py and squashes the following commits: baa2dcd [Xiangrui Meng] fix doc for RegexTokenizer --- python/pyspark/ml/feature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86e654dd0779f..015e7a9d4900a 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -525,7 +525,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol): """ A regex based tokenizer that extracts tokens either by using the provided regex pattern (in Java dialect) to split the text - (default) or repeatedly matching the regex (if gaps is true). + (default) or repeatedly matching the regex (if gaps is false). Optional parameters also allow filtering tokens using a minimal length. It returns an array of strings that can be empty. 
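
Since the same one-word doc fix appears twice in this series (the Scala doc a few patches back, the Python doc here), a short sketch of the corrected `gaps` semantics may help. It uses the Scala API and assumes a DataFrame `df` with a string column "text" containing "a b  c":

    import org.apache.spark.ml.feature.RegexTokenizer

    // gaps = true (the default): the pattern is a delimiter, i.e. split the text on it.
    val bySplitting = new RegexTokenizer()
      .setInputCol("text").setOutputCol("words")
      .setPattern("\\s+").setGaps(true)   // "a b  c" -> Seq("a", "b", "c")

    // gaps = false: the pattern describes the tokens, i.e. repeatedly match it.
    val byMatching = new RegexTokenizer()
      .setInputCol("text").setOutputCol("words")
      .setPattern("\\w+").setGaps(false)  // "a b  c" -> Seq("a", "b", "c")

    val tokenized = byMatching.transform(df)  // adds the "words" array column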
From 7492a33fdd074446c30c657d771a69932a00246d Mon Sep 17 00:00:00 2001 From: Yuu ISHIKAWA Date: Thu, 30 Jul 2015 10:00:27 -0700 Subject: [PATCH 164/219] [SPARK-9248] [SPARKR] Closing curly-braces should always be on their own line ### JIRA [[SPARK-9248] Closing curly-braces should always be on their own line - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9248) ## The result of `dev/lint-r` [The result of `dev/lint-r` for SPARK-9248 at the revistion:6175d6cfe795fbd88e3ee713fac375038a3993a8](https://gist.github.com/yu-iskw/96cadcea4ce664c41f81) Author: Yuu ISHIKAWA Closes #7795 from yu-iskw/SPARK-9248 and squashes the following commits: c8eccd3 [Yuu ISHIKAWA] [SPARK-9248][SparkR] Closing curly-braces should always be on their own line --- R/pkg/R/generics.R | 14 +++++++------- R/pkg/R/pairRDD.R | 4 ++-- R/pkg/R/sparkR.R | 9 ++++++--- R/pkg/inst/tests/test_sparkSQL.R | 6 ++++-- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 836e0175c391f..a3a121058e165 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -254,8 +254,10 @@ setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") # @rdname intersection # @export -setGeneric("intersection", function(x, other, numPartitions = 1) { - standardGeneric("intersection") }) +setGeneric("intersection", + function(x, other, numPartitions = 1) { + standardGeneric("intersection") + }) # @rdname keys # @export @@ -489,9 +491,7 @@ setGeneric("sample", #' @rdname sample #' @export setGeneric("sample_frac", - function(x, withReplacement, fraction, seed) { - standardGeneric("sample_frac") - }) + function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") }) #' @rdname saveAsParquetFile #' @export @@ -553,8 +553,8 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn #' @rdname withColumnRenamed #' @export -setGeneric("withColumnRenamed", function(x, existingCol, newCol) { - standardGeneric("withColumnRenamed") }) +setGeneric("withColumnRenamed", + function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) ###################### Column Methods ########################## diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index ebc6ff65e9d0f..83801d3209700 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -202,8 +202,8 @@ setMethod("partitionBy", packageNamesArr <- serialize(.sparkREnv$.packages, connection = NULL) - broadcastArr <- lapply(ls(.broadcastNames), function(name) { - get(name, .broadcastNames) }) + broadcastArr <- lapply(ls(.broadcastNames), + function(name) { get(name, .broadcastNames) }) jrdd <- getJRDD(x) # We create a PairwiseRRDD that extends RDD[(Int, Array[Byte])], diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 76c15875b50d5..e83104f116422 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -22,7 +22,8 @@ connExists <- function(env) { tryCatch({ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]]) - }, error = function(err) { + }, + error = function(err) { return(FALSE) }) } @@ -153,7 +154,8 @@ sparkR.init <- function( .sparkREnv$backendPort <- backendPort tryCatch({ connectBackend("localhost", backendPort) - }, error = function(err) { + }, + error = function(err) { stop("Failed to connect JVM\n") }) @@ -264,7 +266,8 @@ sparkRHive.init <- function(jsc = NULL) { ssc <- callJMethod(sc, "sc") hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.HiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { stop("Spark SQL is not 
built with Hive support") }) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 62fe48a5d6c7b..d5db97248c770 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -112,7 +112,8 @@ test_that("create DataFrame from RDD", { df <- jsonFile(sqlContext, jsonPathNa) hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) sql(hiveCtx, "CREATE TABLE people (name string, age double, height float)") @@ -602,7 +603,8 @@ test_that("write.df() as parquet file", { test_that("test HiveContext", { hiveCtx <- tryCatch({ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, error = function(err) { + }, + error = function(err) { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") From c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 30 Jul 2015 10:04:30 -0700 Subject: [PATCH 165/219] [SPARK-9390][SQL] create a wrapper for array type Author: Wenchen Fan Closes #7724 from cloud-fan/array-data and squashes the following commits: d0408a1 [Wenchen Fan] fix python 661e608 [Wenchen Fan] rebase f39256c [Wenchen Fan] fix hive... 6dbfa6f [Wenchen Fan] fix hive again... 8cb8842 [Wenchen Fan] remove element type parameter from getArray 43e9816 [Wenchen Fan] fix mllib e719afc [Wenchen Fan] fix hive 4346290 [Wenchen Fan] address comment d4a38da [Wenchen Fan] remove sizeInBytes and add license 7e283e2 [Wenchen Fan] create a wrapper for array type --- .../apache/spark/mllib/linalg/Matrices.scala | 16 +-- .../apache/spark/mllib/linalg/Vectors.scala | 15 +-- .../expressions/SpecializedGetters.java | 2 + .../sql/catalyst/CatalystTypeConverters.scala | 29 +++-- .../spark/sql/catalyst/InternalRow.scala | 2 + .../catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/Cast.scala | 39 ++++-- .../expressions/codegen/CodeGenerator.scala | 28 ++-- .../codegen/GenerateUnsafeProjection.scala | 4 +- .../expressions/collectionOperations.scala | 10 +- .../expressions/complexTypeCreator.scala | 20 ++- .../expressions/complexTypeExtractors.scala | 59 ++++++--- .../sql/catalyst/expressions/generators.scala | 4 +- .../expressions/stringOperations.scala | 12 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../apache/spark/sql/types/ArrayData.scala | 121 ++++++++++++++++++ .../spark/sql/types/GenericArrayData.scala | 59 +++++++++ .../sql/catalyst/expressions/CastSuite.scala | 21 ++- .../expressions/ComplexTypeSuite.scala | 2 +- .../spark/sql/execution/debug/package.scala | 4 +- .../spark/sql/execution/pythonUDFs.scala | 19 ++- .../sql/execution/stat/FrequentItems.scala | 4 +- .../apache/spark/sql/json/InferSchema.scala | 2 +- .../apache/spark/sql/json/JacksonParser.scala | 30 +++-- .../sql/parquet/CatalystRowConverter.scala | 2 +- .../spark/sql/parquet/ParquetConverter.scala | 3 +- .../sql/parquet/ParquetTableSupport.scala | 12 +- .../apache/spark/sql/JavaDataFrameSuite.java | 5 +- .../spark/sql/UserDefinedTypeSuite.scala | 8 +- .../spark/sql/sources/TableScanSuite.scala | 30 ++--- .../spark/sql/hive/HiveInspectors.scala | 28 ++-- .../hive/execution/ScriptTransformation.scala | 12 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 2 +- .../spark/sql/hive/HiveInspectorSuite.scala | 2 +- 34 files changed, 430 insertions(+), 181 deletions(-) create mode 100644 
sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d82ba2456df1a..88914fa875990 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, sm.colPtrs.toSeq) - row.update(4, sm.rowIndices.toSeq) - row.update(5, sm.values.toSeq) + row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) + row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) + row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, dm.values.toSeq) + row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, dm.isTransposed) } row @@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = - row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray - val rowIndices = - row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray + val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) + val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 23c2c16d68d9a..89a1818db0d1d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, indices.toSeq) - row.update(3, values.toSeq) + row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, values.toSeq) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row } } @@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = - row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) + val values = 
row.getArray(3).toArray().map(_.asInstanceOf[Double]) new SparseVector(size, indices, values) case 1 => - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new DenseVector(values) } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index bc345dcd00e49..f7cea13688876 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.ArrayData; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -50,4 +51,5 @@ public interface SpecializedGetters { InternalRow getStruct(int ordinal, int numFields); + ArrayData getArray(int ordinal); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index d1d89a1f48329..22452c0f201ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -55,7 +55,6 @@ object CatalystTypeConverters { private def isWholePrimitive(dt: DataType): Boolean = dt match { case dt if isPrimitive(dt) => true - case ArrayType(elementType, _) => isWholePrimitive(elementType) case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) case _ => false } @@ -154,39 +153,41 @@ object CatalystTypeConverters { /** Converter for arrays, sequences, and Java iterables. 
*/ private case class ArrayConverter( - elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], Seq[Any]] { + elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] { private[this] val elementConverter = getConverterForType(elementType) private[this] val isNoChange = isWholePrimitive(elementType) - override def toCatalystImpl(scalaValue: Any): Seq[Any] = { + override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { - case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) - case s: Seq[_] => s.map(elementConverter.toCatalyst) + case a: Array[_] => + new GenericArrayData(a.map(elementConverter.toCatalyst)) + case s: Seq[_] => + new GenericArrayData(s.map(elementConverter.toCatalyst).toArray) case i: JavaIterable[_] => val iter = i.iterator - var convertedIterable: List[Any] = List() + val convertedIterable = scala.collection.mutable.ArrayBuffer.empty[Any] while (iter.hasNext) { val item = iter.next() - convertedIterable :+= elementConverter.toCatalyst(item) + convertedIterable += elementConverter.toCatalyst(item) } - convertedIterable + new GenericArrayData(convertedIterable.toArray) } } - override def toScala(catalystValue: Seq[Any]): Seq[Any] = { + override def toScala(catalystValue: ArrayData): Seq[Any] = { if (catalystValue == null) { null } else if (isNoChange) { - catalystValue + catalystValue.toArray() } else { - catalystValue.map(elementConverter.toScala) + catalystValue.toArray().map(elementConverter.toScala) } } override def toScalaImpl(row: InternalRow, column: Int): Seq[Any] = - toScala(row.get(column, ArrayType(elementType)).asInstanceOf[Seq[Any]]) + toScala(row.getArray(column)) } private case class MapConverter( @@ -402,9 +403,9 @@ object CatalystTypeConverters { case t: Timestamp => TimestampConverter.toCatalyst(t) case d: BigDecimal => BigDecimalConverter.toCatalyst(d) case d: JavaBigDecimal => BigDecimalConverter.toCatalyst(d) - case seq: Seq[Any] => seq.map(convertToCatalyst) + case seq: Seq[Any] => new GenericArrayData(seq.map(convertToCatalyst).toArray) case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) - case arr: Array[Any] => arr.map(convertToCatalyst) + case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) case m: Map[_, _] => m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index a5999e64ec554..486ba036548c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -76,6 +76,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs[InternalRow](ordinal, null) + override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null) + override def toString: String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 371681b5d494f..45709c1c8f554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -65,7 +65,7 @@ case class 
BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) - val value = ctx.getColumn("i", dataType, ordinal) + val value = ctx.getValue("i", dataType, ordinal.toString) s""" boolean ${ev.isNull} = i.isNullAt($ordinal); $javaType ${ev.primitive} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8c01c13c9ccd5..43be11c48ae7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -363,7 +363,21 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = { val elementCast = cast(from.elementType, to.elementType) - buildCast[Seq[Any]](_, _.map(v => if (v == null) null else elementCast(v))) + // TODO: Could be faster? + buildCast[ArrayData](_, array => { + val length = array.numElements() + val values = new Array[Any](length) + var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + values(i) = null + } else { + values(i) = elementCast(array.get(i)) + } + i += 1 + } + new GenericArrayData(values) + }) } private[this] def castMap(from: MapType, to: MapType): Any => Any = { @@ -789,37 +803,36 @@ case class Cast(child: Expression, dataType: DataType) private[this] def castArrayCode( from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) - - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") val fromElementPrim = ctx.freshName("fePrim") val toElementNull = ctx.freshName("teNull") val toElementPrim = ctx.freshName("tePrim") val size = ctx.freshName("n") val j = ctx.freshName("j") - val result = ctx.freshName("result") + val values = ctx.freshName("values") (c, evPrim, evNull) => s""" - final int $size = $c.size(); - final $arraySeqClass $result = new $arraySeqClass($size); + final int $size = $c.numElements(); + final Object[] $values = new Object[$size]; for (int $j = 0; $j < $size; $j ++) { - if ($c.apply($j) == null) { - $result.update($j, null); + if ($c.isNullAt($j)) { + $values[$j] = null; } else { boolean $fromElementNull = false; ${ctx.javaType(from.elementType)} $fromElementPrim = - (${ctx.boxedType(from.elementType)}) $c.apply($j); + ${ctx.getValue(c, from.elementType, j)}; ${castCode(ctx, fromElementPrim, fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} if ($toElementNull) { - $result.update($j, null); + $values[$j] = null; } else { - $result.update($j, $toElementPrim); + $values[$j] = $toElementPrim; } } } - $evPrim = $result; + $evPrim = new $arrayClass($values); """ } @@ -891,7 +904,7 @@ case class Cast(child: Expression, dataType: DataType) $result.setNullAt($i); } else { $fromType $fromFieldPrim = - ${ctx.getColumn(tmpRow, from.fields(i).dataType, i)}; + ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)}; ${castCode(ctx, fromFieldPrim, fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)} if ($toFieldNull) { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 092f4c9fb0bd2..c39e0df6fae2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -100,17 +100,18 @@ class CodeGenContext { } /** - * Returns the code to access a column in Row for a given DataType. + * Returns the code to access a value in `SpecializedGetters` for a given DataType. */ - def getColumn(row: String, dataType: DataType, ordinal: Int): String = { + def getValue(getter: String, dataType: DataType, ordinal: String): String = { val jt = javaType(dataType) dataType match { - case _ if isPrimitiveType(jt) => s"$row.get${primitiveTypeName(jt)}($ordinal)" - case StringType => s"$row.getUTF8String($ordinal)" - case BinaryType => s"$row.getBinary($ordinal)" - case CalendarIntervalType => s"$row.getInterval($ordinal)" - case t: StructType => s"$row.getStruct($ordinal, ${t.size})" - case _ => s"($jt)$row.get($ordinal)" + case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case StringType => s"$getter.getUTF8String($ordinal)" + case BinaryType => s"$getter.getBinary($ordinal)" + case CalendarIntervalType => s"$getter.getInterval($ordinal)" + case t: StructType => s"$getter.getStruct($ordinal, ${t.size})" + case a: ArrayType => s"$getter.getArray($ordinal)" + case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter. } } @@ -152,8 +153,8 @@ class CodeGenContext { case StringType => "UTF8String" case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" - case _: ArrayType => s"scala.collection.Seq" - case _: MapType => s"scala.collection.Map" + case _: ArrayType => "ArrayData" + case _: MapType => "scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -214,7 +215,9 @@ class CodeGenContext { case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" case NullType => "0" - case other => s"$c1.compare($c2)" + case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" + case _ => throw new IllegalArgumentException( + "cannot generate compare code for un-comparable type") } /** @@ -293,7 +296,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UnsafeRow].getName, classOf[UTF8String].getName, classOf[Decimal].getName, - classOf[CalendarInterval].getName + classOf[CalendarInterval].getName, + classOf[ArrayData].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7be60114ce674..a662357fb6cf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -153,14 +153,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val nestedStructEv = GeneratedExpressionCode( code = "", isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" ) createCodeForStruct(ctx, nestedStructEv, st) case _ => GeneratedExpressionCode( code = "", isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getColumn(input.primitive, dt, i)}" + primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2d92dcf23a86e..1a00dbc254de1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,11 +27,15 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) override def nullSafeEval(value: Any): Int = child.dataType match { - case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size - case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + case _: ArrayType => value.asInstanceOf[ArrayData].numElements() + case _: MapType => value.asInstanceOf[Map[Any, Any]].size } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + val sizeCall = child.dataType match { + case _: ArrayType => "numElements()" + case _: MapType => "size()" + } + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 0517050a45109..a145dfb4bbf08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,12 
+18,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.unsafe.types.UTF8String - -import scala.collection.mutable - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -46,25 +43,26 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false override def eval(input: InternalRow): Any = { - children.map(_.eval(input)) + new GenericArrayData(children.map(_.eval(input)).toArray) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = classOf[mutable.ArraySeq[Any]].getName + val arrayClass = classOf[GenericArrayData].getName s""" - boolean ${ev.isNull} = false; - $arraySeqClass ${ev.primitive} = new $arraySeqClass(${children.size}); + final boolean ${ev.isNull} = false; + final Object[] values = new Object[${children.size}]; """ + children.zipWithIndex.map { case (e, i) => val eval = e.gen(ctx) eval.code + s""" if (${eval.isNull}) { - ${ev.primitive}.update($i, null); + values[$i] = null; } else { - ${ev.primitive}.update($i, ${eval.primitive}); + values[$i] = ${eval.primitive}; } """ - }.mkString("\n") + }.mkString("\n") + + s"final ${ctx.javaType(dataType)} ${ev.primitive} = new $arrayClass(values);" } override def prettyName: String = "array" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 6331a9eb603ca..99393c9c76ab6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,8 @@ object ExtractValue { case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) - GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), + ordinal, fields.length, containsNull) case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) @@ -118,7 +119,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) if ($eval.isNullAt($ordinal)) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${ctx.getColumn(eval, dataType, ordinal)}; + ${ev.primitive} = ${ctx.getValue(eval, dataType, ordinal.toString)}; } """ }) @@ -134,6 +135,7 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, + numFields: Int, containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) @@ -141,26 +143,45 @@ case class GetArrayStructFields( override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { - input.asInstanceOf[Seq[InternalRow]].map { row => - if (row == null) null else row.get(ordinal, field.dataType) + val array = input.asInstanceOf[ArrayData] + val length = array.numElements() + val result = new Array[Any](length) 
+ var i = 0 + while (i < length) { + if (array.isNullAt(i)) { + result(i) = null + } else { + val row = array.getStruct(i, numFields) + if (row.isNullAt(ordinal)) { + result(i) = null + } else { + result(i) = row.get(ordinal, field.dataType) + } + } + i += 1 } + new GenericArrayData(result) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arraySeqClass = "scala.collection.mutable.ArraySeq" - // TODO: consider using Array[_] for ArrayType child to avoid - // boxing of primitives + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { s""" - final int n = $eval.size(); - final $arraySeqClass values = new $arraySeqClass(n); + final int n = $eval.numElements(); + final Object[] values = new Object[n]; for (int j = 0; j < n; j++) { - InternalRow row = (InternalRow) $eval.apply(j); - if (row != null && !row.isNullAt($ordinal)) { - values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + if ($eval.isNullAt(j)) { + values[j] = null; + } else { + final InternalRow row = $eval.getStruct(j, $numFields); + if (row.isNullAt($ordinal)) { + values[j] = null; + } else { + values[j] = ${ctx.getValue("row", field.dataType, ordinal.toString)}; + } } } - ${ev.primitive} = (${ctx.javaType(dataType)}) values; + ${ev.primitive} = new $arrayClass(values); """ }) } @@ -186,23 +207,23 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx protected override def nullSafeEval(value: Any, ordinal: Any): Any = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives - val baseValue = value.asInstanceOf[Seq[_]] + val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.size || index < 0) { + if (index >= baseValue.numElements() || index < 0) { null } else { - baseValue(index) + baseValue.get(index) } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - final int index = (int)$eval2; - if (index >= $eval1.size() || index < 0) { + final int index = (int) $eval2; + if (index >= $eval1.numElements() || index < 0) { ${ev.isNull} = true; } else { - ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply(index); + ${ev.primitive} = ${ctx.getValue(eval1, dataType, "index")}; } """ }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 2dbcf2830f876..8064235c64ef9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -121,8 +121,8 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { case ArrayType(_, _) => - val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) + val inputArray = child.eval(input).asInstanceOf[ArrayData] + if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v)) case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] if (inputMap == null) Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5b3a64a09679c..79c0ca56a8e79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -92,7 +92,7 @@ case class ConcatWs(children: Seq[Expression]) val flatInputs = children.flatMap { child => child.eval(input) match { case s: UTF8String => Iterator(s) - case arr: Seq[_] => arr.asInstanceOf[Seq[UTF8String]] + case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String]) case null => Iterator(null.asInstanceOf[UTF8String]) } } @@ -105,7 +105,7 @@ case class ConcatWs(children: Seq[Expression]) val evals = children.map(_.gen(ctx)) val inputs = evals.map { eval => - s"${eval.isNull} ? (UTF8String)null : ${eval.primitive}" + s"${eval.isNull} ? (UTF8String) null : ${eval.primitive}" }.mkString(", ") evals.map(_.code).mkString("\n") + s""" @@ -665,13 +665,15 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => - s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( - java.util.Arrays.asList($str.split($pattern, -1)));""") + // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. + s"""${ev.primitive} = new $arrayClass($str.split($pattern, -1));""") } override def prettyName: String = "split" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 813c62009666c..29d706dcb39a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -312,7 +312,8 @@ object NullPropagation extends Rule[LogicalPlan] { case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType) case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType) case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType) - case e @ GetArrayStructFields(Literal(null, _), _, _, _) => Literal.create(null, e.dataType) + case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) => + Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) case e @ Count(expr) if !expr.nullable => Count(Literal(1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala new file mode 100644 index 0000000000000..14a7285877622 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters + +abstract class ArrayData extends SpecializedGetters with Serializable { + // todo: remove this after we handle all types.(map type need special getter) + def get(ordinal: Int): Any + + def numElements(): Int + + // todo: need a more efficient way to iterate array type. + def toArray(): Array[Any] = { + val n = numElements() + val values = new Array[Any](n) + var i = 0 + while (i < n) { + if (isNullAt(i)) { + values(i) = null + } else { + values(i) = get(i) + } + i += 1 + } + values + } + + override def toString(): String = toArray.mkString("[", ",", "]") + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayData]) { + return false + } + + val other = o.asInstanceOf[ArrayData] + if (other eq null) { + return false + } + + val len = numElements() + if (len != other.numElements()) { + return false + } + + var i = 0 + while (i < len) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + if (!isNullAt(i)) { + val o1 = get(i) + val o2 = other.get(i) + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { + return false + } + } + } + i += 1 + } + true + } + + override def hashCode: Int = { + var result: Int = 37 + var i = 0 + val len = numElements() + while (i < len) { + val update: Int = + if (isNullAt(i)) { + 0 + } else { + get(i) match { + case b: Boolean => if (b) 0 else 1 + case b: Byte => b.toInt + case s: Short => s.toInt + case i: Int => i + case l: Long => (l ^ (l >>> 32)).toInt + case f: Float => java.lang.Float.floatToIntBits(f) + case d: Double => + val b = java.lang.Double.doubleToLongBits(d) + (b ^ (b >>> 32)).toInt + case a: Array[Byte] => java.util.Arrays.hashCode(a) + case other => other.hashCode() + } + } + result = 37 * result + update + i += 1 + } + result + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala new file mode 100644 index 0000000000000..7992ba947c069 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval} + +class GenericArrayData(array: Array[Any]) extends ArrayData { + private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T] + + override def toArray(): Array[Any] = array + + override def get(ordinal: Int): Any = array(ordinal) + + override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null + + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + + override def getByte(ordinal: Int): Byte = getAs(ordinal) + + override def getShort(ordinal: Int): Short = getAs(ordinal) + + override def getInt(ordinal: Int): Int = getAs(ordinal) + + override def getLong(ordinal: Int): Long = getAs(ordinal) + + override def getFloat(ordinal: Int): Float = getAs(ordinal) + + override def getDouble(ordinal: Int): Double = getAs(ordinal) + + override def getDecimal(ordinal: Int): Decimal = getAs(ordinal) + + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + + override def numElements(): Int = array.length +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a517da9872852..4f35b653d73c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Timestamp, Date} import java.util.{TimeZone, Calendar} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -730,13 +731,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("complex casting") { val complex = Literal.create( - InternalRow( - Seq(UTF8String.fromString("123"), UTF8String.fromString("abc"), UTF8String.fromString("")), - Map( - UTF8String.fromString("a") -> UTF8String.fromString("123"), - UTF8String.fromString("b") -> UTF8String.fromString("abc"), - UTF8String.fromString("c") -> UTF8String.fromString("")), - InternalRow(0)), + Row( + Seq("123", "abc", ""), + Map("a" ->"123", "b" -> "abc", "c" -> ""), + Row(0)), StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false), nullable = true), @@ -756,13 +754,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("l", LongType, nullable = true))))))) assert(ret.resolved === true) - 
checkEvaluation(ret, InternalRow( + checkEvaluation(ret, Row( Seq(123, null, null), - Map( - UTF8String.fromString("a") -> true, - UTF8String.fromString("b") -> true, - UTF8String.fromString("c") -> false), - InternalRow(0L))) + Map("a" -> true, "b" -> true, "c" -> false), + Row(0L))) } test("case between string and interval") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 5de5ddce975d8..3fa246b69d1f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -110,7 +110,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { expr.dataType match { case ArrayType(StructType(fields), containsNull) => val field = fields.find(_.name == fieldName).get - GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index aeeb0e45270dd..f26f41fb75d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -158,8 +158,8 @@ package object debug { case (row: InternalRow, StructType(fields)) => row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) } - case (s: Seq[_], ArrayType(elemType, _)) => - s.foreach(typeCheck(_, elemType)) + case (a: ArrayData, ArrayType(elemType, _)) => + a.toArray().foreach(typeCheck(_, elemType)) case (m: Map[_, _], MapType(keyType, valueType, _)) => m.keys.foreach(typeCheck(_, keyType)) m.values.foreach(typeCheck(_, valueType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 3c38916fd7504..ef1c6e57dc08a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -134,8 +134,19 @@ object EvaluatePython { } new GenericInternalRowWithSchema(values, struct) - case (seq: Seq[Any], array: ArrayType) => - seq.map(x => toJava(x, array.elementType)).asJava + case (a: ArrayData, array: ArrayType) => + val length = a.numElements() + val values = new java.util.ArrayList[Any](length) + var i = 0 + while (i < length) { + if (a.isNullAt(i)) { + values.add(null) + } else { + values.add(toJava(a.get(i), array.elementType)) + } + i += 1 + } + values case (obj: Map[_, _], mt: MapType) => obj.map { case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) @@ -190,10 +201,10 @@ object EvaluatePython { case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}.toSeq + new GenericArrayData(c.map { e => fromJava(e, elementType)}.toArray) case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq + new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => 
c.map { case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 78da2840dad69..9329148aa233c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{DataType, ArrayType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.{Column, DataFrame} private[sql] object FrequentItems extends Logging { @@ -110,7 +110,7 @@ private[sql] object FrequentItems extends Logging { baseCounts } ) - val justItems = freqItems.map(m => m.baseMap.keys.toSeq) + val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_)) val resultRow = InternalRow(justItems : _*) // append frequent Items to the column name for easy debugging val outputCols = colInfo.map { v => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala index 0eb3b04007f8d..04ab5e2217882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala @@ -125,7 +125,7 @@ private[sql] object InferSchema { * Convert NullType to StringType and remove StructTypes with no fields */ private def canonicalizeType: DataType => Option[DataType] = { - case at@ArrayType(elementType, _) => + case at @ ArrayType(elementType, _) => for { canonicalType <- canonicalizeType(elementType) } yield { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala index 381e7ed54428f..1c309f8794ef3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala @@ -110,8 +110,13 @@ private[sql] object JacksonParser { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) + case (START_ARRAY, st: StructType) => + // SPARK-3308: support reading top level JSON arrays and take every element + // in such an array as a row + convertArray(factory, parser, st) + case (START_ARRAY, ArrayType(st, _)) => - convertList(factory, parser, st) + convertArray(factory, parser, st) case (START_OBJECT, ArrayType(st, _)) => // the business end of SPARK-3308: @@ -165,16 +170,16 @@ private[sql] object JacksonParser { builder.result() } - private def convertList( + private def convertArray( factory: JsonFactory, parser: JsonParser, - schema: DataType): Seq[Any] = { - val builder = Seq.newBuilder[Any] + elementType: DataType): ArrayData = { + val values = scala.collection.mutable.ArrayBuffer.empty[Any] while (nextUntil(parser, JsonToken.END_ARRAY)) { - builder += convertField(factory, parser, schema) + values += convertField(factory, parser, elementType) } - builder.result() + new GenericArrayData(values.toArray) } private def parseJson( @@ -201,12 +206,15 @@ private[sql] object JacksonParser { val parser = factory.createParser(record) parser.nextToken() - // to support both object and 
arrays (see SPARK-3308) we'll start - // by converting the StructType schema to an ArrayType and let - // convertField wrap an object into a single value array when necessary. - convertField(factory, parser, ArrayType(schema)) match { + convertField(factory, parser, schema) match { case null => failedRecord(record) - case list: Seq[InternalRow @unchecked] => list + case row: InternalRow => row :: Nil + case array: ArrayData => + if (array.numElements() == 0) { + Nil + } else { + array.toArray().map(_.asInstanceOf[InternalRow]) + } case _ => sys.error( s"Failed to parse record $record. Please make sure that each line of the file " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala index e00bd90edb3dd..172db8362afb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala @@ -325,7 +325,7 @@ private[parquet] class CatalystRowConverter( override def getConverter(fieldIndex: Int): Converter = elementConverter - override def end(): Unit = updater.set(currentArray) + override def end(): Unit = updater.set(new GenericArrayData(currentArray.toArray)) // NOTE: We can't reuse the mutable `ArrayBuffer` here and must instantiate a new buffer for the // next value. `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index ea51650fe9039..2332a36468dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.parquet import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.ArrayData // TODO Removes this while fixing SPARK-8848 private[sql] object CatalystConverter { @@ -32,7 +33,7 @@ private[sql] object CatalystConverter { val MAP_SCHEMA_NAME = "map" // TODO: consider using Array[T] for arrays to avoid boxing of primitive types - type ArrayScalaType[T] = Seq[T] + type ArrayScalaType[T] = ArrayData type StructScalaType[T] = InternalRow type MapScalaType[K, V] = Map[K, V] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 78ecfad1d57c6..79dd16b7b0c39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -146,15 +146,15 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo array: CatalystConverter.ArrayScalaType[_]): Unit = { val elementType = schema.elementType writer.startGroup() - if (array.size > 0) { + if (array.numElements() > 0) { if (schema.containsNull) { writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { + while (i < array.numElements()) { writer.startGroup() - if (array(i) != null) { + if (!array.isNullAt(i)) { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) - writeValue(elementType, array(i)) + writeValue(elementType, array.get(i)) writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) } writer.endGroup() @@ 
-164,8 +164,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo } else { writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) var i = 0 - while (i < array.size) { - writeValue(elementType, array(i)) + while (i < array.numElements()) { + writeValue(elementType, array.get(i)) i = i + 1 } writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 72c42f4fe376b..9e61d06f4036e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -30,7 +30,6 @@ import scala.collection.JavaConversions; import scala.collection.Seq; -import scala.collection.mutable.Buffer; import java.io.Serializable; import java.util.Arrays; @@ -168,10 +167,10 @@ public void testCreateDataFrameFromJavaBeans() { for (int i = 0; i < result.length(); i++) { Assert.assertEquals(bean.getB()[i], result.apply(i)); } - Buffer outputBuffer = (Buffer) first.getJavaMap(2).get("hello"); + Seq outputBuffer = (Seq) first.getJavaMap(2).get("hello"); Assert.assertArrayEquals( bean.getC().get("hello"), - Ints.toArray(JavaConversions.bufferAsJavaList(outputBuffer))); + Ints.toArray(JavaConversions.seqAsJavaList(outputBuffer))); Seq d = first.getAs(3); Assert.assertEquals(bean.getD().size(), d.length()); for (int i = 0; i < d.length(); i++) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 45c9f06941c10..77ed4a9c0d5ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -47,17 +47,17 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(obj: Any): Seq[Double] = { + override def serialize(obj: Any): ArrayData = { obj match { case features: MyDenseVector => - features.data.toSeq + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) } } override def deserialize(datum: Any): MyDenseVector = { datum match { - case data: Seq[_] => - new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + case data: ArrayData => + new MyDenseVector(data.toArray.map(_.asInstanceOf[Double])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 5e189c3563ca8..cfb03ff485b7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -67,12 +67,12 @@ case class AllDataTypesScan( override def schema: StructType = userSpecifiedSchema - override def needConversion: Boolean = false + override def needConversion: Boolean = true override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { i => - InternalRow( - UTF8String.fromString(s"str_$i"), + Row( + s"str_$i", s"str_$i".getBytes(), i % 2 == 0, i.toByte, @@ -81,19 +81,19 @@ case class AllDataTypesScan( i.toLong, i.toFloat, i.toDouble, - Decimal(new java.math.BigDecimal(i)), - Decimal(new java.math.BigDecimal(i)), - DateTimeUtils.fromJavaDate(new Date(1970, 1, 1)), - 
DateTimeUtils.fromJavaTimestamp(new Timestamp(20000 + i)), - UTF8String.fromString(s"varchar_$i"), + new java.math.BigDecimal(i), + new java.math.BigDecimal(i), + new Date(1970, 1, 1), + new Timestamp(20000 + i), + s"varchar_$i", Seq(i, i + 1), - Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> UTF8String.fromString(i.toString)), - Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - InternalRow(i, UTF8String.fromString(i.toString)), - InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), - InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) - }.asInstanceOf[RDD[Row]] + Seq(Map(s"str_$i" -> Row(i.toLong))), + Map(i -> i.toString), + Map(Map(s"str_$i" -> i.toFloat) -> Row(i.toLong)), + Row(i, i.toString), + Row(Seq(s"str_$i", s"str_${i + 1}"), + Row(Seq(new Date(1970, 1, i + 1))))) + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index f467500259c91..5926ef9aa388b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -52,9 +52,8 @@ import scala.collection.JavaConversions._ * java.sql.Timestamp * Complex Types => * Map: scala.collection.immutable.Map - * List: scala.collection.immutable.Seq - * Struct: - * [[org.apache.spark.sql.catalyst.InternalRow]] + * List: [[org.apache.spark.sql.types.ArrayData]] + * Struct: [[org.apache.spark.sql.catalyst.InternalRow]] * Union: NOT SUPPORTED YET * The Complex types plays as a container, which can hold arbitrary data types. * @@ -297,7 +296,10 @@ private[hive] trait HiveInspectors { }.toMap case li: StandardConstantListObjectInspector => // take the value from the list inspector object, rather than the input data - li.getWritableConstantValue.map(unwrap(_, li.getListElementObjectInspector)).toSeq + val values = li.getWritableConstantValue + .map(unwrap(_, li.getListElementObjectInspector)) + .toArray + new GenericArrayData(values) // if the value is null, we don't care about the object inspector type case _ if data == null => null case poi: VoidObjectInspector => null // always be null for void object inspector @@ -339,7 +341,10 @@ private[hive] trait HiveInspectors { } case li: ListObjectInspector => Option(li.getList(data)) - .map(_.map(unwrap(_, li.getListElementObjectInspector)).toSeq) + .map { l => + val values = l.map(unwrap(_, li.getListElementObjectInspector)).toArray + new GenericArrayData(values) + } .orNull case mi: MapObjectInspector => Option(mi.getMap(data)).map( @@ -391,7 +396,13 @@ private[hive] trait HiveInspectors { case loi: ListObjectInspector => val wrapper = wrapperFor(loi.getListElementObjectInspector) - (o: Any) => if (o != null) seqAsJavaList(o.asInstanceOf[Seq[_]].map(wrapper)) else null + (o: Any) => { + if (o != null) { + seqAsJavaList(o.asInstanceOf[ArrayData].toArray().map(wrapper)) + } else { + null + } + } case moi: MapObjectInspector => // The Predef.Map is scala.collection.immutable.Map. 
@@ -520,7 +531,7 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[Seq[_]].foreach { + a.asInstanceOf[ArrayData].toArray().foreach { v => list.add(wrap(v, x.getListElementObjectInspector, tpe)) } list @@ -634,7 +645,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[Seq[_]].foreach(v => list.add(wrap(v, listObjectInspector, dt))) + value.asInstanceOf[ArrayData].toArray() + .foreach(v => list.add(wrap(v, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 741c705e2a253..7e3342cc84c0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -176,13 +176,13 @@ case class ScriptTransformation( val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) + .map(CatalystTypeConverters.convertToCatalyst)) } else { - new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) - .asInstanceOf[Array[Any]]) + new GenericInternalRow( + prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) + .map(CatalystTypeConverters.convertToCatalyst)) } } else { val ret = deserialize() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 8732e9abf8d31..4a13022eddf60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -431,7 +431,7 @@ private[hive] case class HiveWindowFunction( // if pivotResult is true, we will get a Seq having the same size with the size // of the window frame. At here, we will return the result at the position of // index in the output buffer. 
-      outputBuffer.asInstanceOf[Seq[Any]].get(index)
+      outputBuffer.asInstanceOf[ArrayData].get(index)
     }
   }

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 0330013f5325e..f719f2e06ab63 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -217,7 +217,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
   test("wrap / unwrap Array Type") {
     val dt = ArrayType(dataTypes(0))
-    val d = row(0) :: row(0) :: Nil
+    val d = new GenericArrayData(Array(row(0), row(0)))
     checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt)))
     checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt)))
     checkValue(d,

From 7bbf02f0bddefd19985372af79e906a38bc528b6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?=
Date: Thu, 30 Jul 2015 18:14:08 +0100
Subject: [PATCH 166/219] [SPARK-9267] [CORE] Retire stringify(Partial)?Value
 from Accumulators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

cc srowen

Author: François Garillot

Closes #7678 from huitseeker/master and squashes the following commits:

5e99f57 [François Garillot] [SPARK-9267][Core] Retire stringify(Partial)?Value from Accumulators
---
 core/src/main/scala/org/apache/spark/Accumulators.scala | 3 ---
 .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 ++----
 2 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 2f4fcac890eef..eb75f26718e19 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -341,7 +341,4 @@ private[spark] object Accumulators extends Logging {
     }
   }

-  def stringifyPartialValue(partialValue: Any): String = "%s".format(partialValue)
-
-  def stringifyValue(value: Any): String = "%s".format(value)
 }

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cdf6078421123..c4fa277c21254 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -916,11 +916,9 @@ class DAGScheduler(
           // To avoid UI cruft, ignore cases where value wasn't updated
           if (acc.name.isDefined && partialValue != acc.zero) {
             val name = acc.name.get
-            val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
-            val stringValue = Accumulators.stringifyValue(acc.value)
-            stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
+            stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, s"${acc.value}")
             event.taskInfo.accumulables +=
-              AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
+              AccumulableInfo(id, name, Some(s"$partialValue"), s"${acc.value}")
           }
         }
       } catch {

From 5363ed71568c3e7c082146d654a9c669d692d894 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 30 Jul 2015 10:30:37 -0700
Subject: [PATCH 167/219] [SPARK-9361] [SQL] Refactor new aggregation code to
 reduce the times of checking compatibility

JIRA: https://issues.apache.org/jira/browse/SPARK-9361

Currently, we call `aggregate.Utils.tryConvert` in many places to check whether a
logical.Aggregate can be run with the new aggregation code path. But
`aggregate.Utils.tryConvert` takes considerable time to run, so we should call it
only once, cache its value in `logical.Aggregate`, and reuse it. The `tryConvert`
code in `org.apache.spark.sql.execution.aggregate.Utils` should also move to
Catalyst, because it does not actually deal with execution details.
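The fix is plain memoization on the plan node: `logical.Aggregate` gains a
`lazy val newAggregation` that caches the result of `Utils.tryConvert(this)`.
A minimal sketch of that pattern, using hypothetical stand-in types rather than
Spark's real classes:

    object LazyConversionSketch {
      // Stand-in for logical.Aggregate; immutable, so caching is safe.
      final case class Plan(groupingKeys: Seq[String]) {
        // Computed at most once per node; later reads are cheap field accesses.
        lazy val converted: Option[Plan] = tryConvert(this)
      }

      // Stand-in for the expensive compatibility check.
      private def tryConvert(p: Plan): Option[Plan] = {
        Thread.sleep(10) // pretend this costs real time
        if (p.groupingKeys.nonEmpty) Some(p) else None
      }

      def main(args: Array[String]): Unit = {
        val p = Plan(Seq("a"))
        // The planner may probe many times; the conversion itself runs only once.
        (1 to 5).foreach(_ => assert(p.converted.isDefined))
      }
    }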
Author: Liang-Chi Hsieh

Closes #7677 from viirya/refactor_aggregate and squashes the following commits:

babea30 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into refactor_aggregate
9a589d7 [Liang-Chi Hsieh] Fix scala style.
0a91329 [Liang-Chi Hsieh] Refactor new aggregation code to reduce the times to call tryConvert.
---
 .../expressions/aggregate/interfaces.scala    |   4 +-
 .../expressions/aggregate/utils.scala         | 167 ++++++++++++++++++
 .../plans/logical/basicOperators.scala        |   3 +
 .../spark/sql/execution/SparkStrategies.scala |  34 ++--
 .../spark/sql/execution/aggregate/utils.scala | 144 ---------------
 5 files changed, 188 insertions(+), 164 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 9fb7623172e78..d08f553cefe8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -42,7 +42,7 @@ private[sql] case object Partial extends AggregateMode
 private[sql] case object PartialMerge extends AggregateMode

 /**
- * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers
  * containing intermediate results for this function and then generate final result.
  * This function updates the given aggregation buffer by merging multiple aggregation buffers.
  * When it has processed all input rows, the final result of this function is returned.
@@ -50,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode
 private[sql] case object Final extends AggregateMode

 /**
- * An [[AggregateFunction2]] with [[Partial]] mode is used to evaluate this function directly
+ * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly
  * from original input rows without any partial aggregation.
  * This function updates the given aggregation buffer with the original input of this
  * function. When it has processed all input rows, the final result of this function is returned.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
new file mode 100644
index 0000000000000..4a43318a95490
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. + case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expressions, check if it contains AggregateExpression1. + expr.find { + case agg: expressions.AggregateExpression1 => true + case other => false + }.isDefined + } + + // Check if there are multiple distinct columns. 
+ val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val hasMultipleDistinctColumnSets = + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + true + } else { + false + } + + if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None + + case other => None + } + + def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { + // If the plan cannot be converted, we will do a final round check to see if the original + // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, + // we need to throw an exception. + val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg.aggregateFunction + } + }.distinct + if (aggregateFunction2s.nonEmpty) { + // For functions implemented based on the new interface, prepare a list of function names. + val invalidFunctions = { + if (aggregateFunction2s.length > 1) { + s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + + s"and ${aggregateFunction2s.head.nodeName} are" + } else { + s"${aggregateFunction2s.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} implemented based on the new Aggregate Function " + + s"interface and it cannot be used with functions implemented based on " + + s"the old Aggregate Function interface." + throw new AnalysisException(errorMessage) + } + } + + def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate => + val converted = doConvert(p) + if (converted.isDefined) { + converted + } else { + checkInvalidAggregateFunction2(p) + None + } + case other => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ad5af19578f33..a67f8de6b733a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -219,6 +220,8 @@ case class Aggregate( expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions } + lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) + override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f3ef066528ff8..52a9b02d373c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import 
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -193,11 +193,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = { - aggregate.Utils.tryConvert( - plan, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined + def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { + case a: logical.Aggregate => + if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { + a.newAggregation.isDefined + } else { + Utils.checkInvalidAggregateFunction2(a) + false + } + case _ => false } def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall { @@ -217,12 +221,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate => - val converted = - aggregate.Utils.tryConvert( - p, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled) + case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && + sqlContext.conf.codegenEnabled => + val converted = p.newAggregation converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => @@ -377,17 +378,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ logical.Expand(_, _, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = - aggregate.Utils.tryConvert( - a, - sqlContext.conf.useSqlAggregate2, - sqlContext.conf.codegenEnabled).isDefined - if (useNewAggregation) { + val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled + if (useNewAggregation && a.newAggregation.isDefined) { // If this logical.Aggregate can be planned to use new aggregation code path // (i.e. it can be planned by the Strategy Aggregation), we will not use the old // aggregation code path. Nil } else { + Utils.checkInvalidAggregateFunction2(a) execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 6549c87752a7d..03635baae4a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -29,150 +29,6 @@ import org.apache.spark.sql.types.{StructType, MapType, ArrayType} * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - // Right now, we do not support complex types in the grouping key schema. 
- private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - - !hasComplexTypes - } - - private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - } - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. 
- val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." - throw new AnalysisException(errorMessage) - } - } - - def tryConvert( - plan: LogicalPlan, - useNewAggregation: Boolean, - codeGenEnabled: Boolean): Option[Aggregate] = plan match { - case p: Aggregate if useNewAggregation && codeGenEnabled => - val converted = tryConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case p: Aggregate => - checkInvalidAggregateFunction2(p) - None - case other => None - } - def planAggregateWithoutDistinct( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[AggregateExpression2], From e53534655d6198e5b8a507010d26c7b4c4e7f1fd Mon Sep 17 00:00:00 2001 From: Mridul Muralidharan Date: Thu, 30 Jul 2015 10:37:53 -0700 Subject: [PATCH 168/219] [SPARK-8297] [YARN] Scheduler backend is not notified in case node fails in YARN This change adds code to notify the scheduler backend when a container dies in YARN. Author: Mridul Muralidharan Author: Marcelo Vanzin Closes #7431 from vanzin/SPARK-8297 and squashes the following commits: 471e4a0 [Marcelo Vanzin] Fix unit test after merge. d4adf4e [Marcelo Vanzin] Merge branch 'master' into SPARK-8297 3b262e8 [Marcelo Vanzin] Merge branch 'master' into SPARK-8297 537da6f [Marcelo Vanzin] Make an expected log less scary. 04dc112 [Marcelo Vanzin] Use driver <-> AM communication to send "remove executor" request. 
8855b97 [Marcelo Vanzin] Merge remote-tracking branch 'mridul/fix_yarn_scheduler_bug' into SPARK-8297 687790f [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug e1b0067 [Mridul Muralidharan] Fix failing testcase, fix merge issue from our 1.3 -> master 9218fcc [Mridul Muralidharan] Fix failing testcase 362d64a [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug 62ad0cc [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug bbf8811 [Mridul Muralidharan] Merge branch 'fix_yarn_scheduler_bug' of github.com:mridulm/spark into fix_yarn_scheduler_bug 9ee1307 [Mridul Muralidharan] Fix SPARK-8297 a3a0f01 [Mridul Muralidharan] Fix SPARK-8297 --- .../CoarseGrainedSchedulerBackend.scala | 2 +- .../cluster/YarnSchedulerBackend.scala | 2 ++ .../spark/deploy/yarn/ApplicationMaster.scala | 22 +++++++++---- .../spark/deploy/yarn/YarnAllocator.scala | 32 +++++++++++++++---- .../spark/deploy/yarn/YarnRMClient.scala | 5 ++- .../deploy/yarn/YarnAllocatorSuite.scala | 29 +++++++++++++++++ 6 files changed, 77 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 660702f6e6fd0..bd89160af4ffa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -241,7 +241,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.executorLost(executorId, SlaveLost(reason)) listenerBus.post( SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) - case None => logError(s"Asked to remove non-existent executor $executorId") + case None => logInfo(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 074282d1be37d..044f6288fabdd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -109,6 +109,8 @@ private[spark] abstract class YarnSchedulerBackend( case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) + case RemoveExecutor(executorId, reason) => + removeExecutor(executorId, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 44acc7374d024..1d67b3ebb51b7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -229,7 +229,11 @@ private[spark] class ApplicationMaster( sparkContextRef.compareAndSet(sc, null) } - private def registerAM(_rpcEnv: RpcEnv, uiAddress: String, securityMgr: SecurityManager) = { + private def registerAM( + _rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + uiAddress: String, + securityMgr: SecurityManager) = { val sc = sparkContextRef.get() val appId = 
client.getAttemptId().getApplicationId().toString() @@ -246,6 +250,7 @@ private[spark] class ApplicationMaster( RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt), CoarseGrainedSchedulerBackend.ENDPOINT_NAME) allocator = client.register(driverUrl, + driverRef, yarnConf, _sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), @@ -262,17 +267,20 @@ private[spark] class ApplicationMaster( * * In cluster mode, the AM and the driver belong to same process * so the AMEndpoint need not monitor lifecycle of the driver. + * + * @return A reference to the driver's RPC endpoint. */ private def runAMEndpoint( host: String, port: String, - isClusterMode: Boolean): Unit = { + isClusterMode: Boolean): RpcEndpointRef = { val driverEndpoint = rpcEnv.setupEndpointRef( SparkEnv.driverActorSystemName, RpcAddress(host, port.toInt), YarnSchedulerBackend.ENDPOINT_NAME) amEndpoint = rpcEnv.setupEndpoint("YarnAM", new AMEndpoint(rpcEnv, driverEndpoint, isClusterMode)) + driverEndpoint } private def runDriver(securityMgr: SecurityManager): Unit = { @@ -290,11 +298,11 @@ private[spark] class ApplicationMaster( "Timed out waiting for SparkContext.") } else { rpcEnv = sc.env.rpcEnv - runAMEndpoint( + val driverRef = runAMEndpoint( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port"), isClusterMode = true) - registerAM(rpcEnv, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) + registerAM(rpcEnv, driverRef, sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } } @@ -302,9 +310,9 @@ private[spark] class ApplicationMaster( private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { val port = sparkConf.getInt("spark.yarn.am.port", 0) rpcEnv = RpcEnv.create("sparkYarnAM", Utils.localHostName, port, sparkConf, securityMgr) - waitForSparkDriver() + val driverRef = waitForSparkDriver() addAmIpFilter() - registerAM(rpcEnv, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) + registerAM(rpcEnv, driverRef, sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. 
reporterThread.join() @@ -428,7 +436,7 @@ private[spark] class ApplicationMaster( } } - private def waitForSparkDriver(): Unit = { + private def waitForSparkDriver(): RpcEndpointRef = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 6c103394af098..59caa787b6e20 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -36,6 +36,9 @@ import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ /** * YarnAllocator is charged with requesting containers from the YARN ResourceManager and deciding @@ -52,6 +55,7 @@ import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ */ private[yarn] class YarnAllocator( driverUrl: String, + driverRef: RpcEndpointRef, conf: Configuration, sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], @@ -88,6 +92,9 @@ private[yarn] class YarnAllocator( // Visible for testing. private[yarn] val executorIdToContainer = new HashMap[String, Container] + private var numUnexpectedContainerRelease = 0L + private val containerIdToExecutorId = new HashMap[ContainerId, String] + // Executor memory in MB. protected val executorMemory = args.executorMemory // Additional memory overhead. @@ -184,6 +191,7 @@ private[yarn] class YarnAllocator( def killExecutor(executorId: String): Unit = synchronized { if (executorIdToContainer.contains(executorId)) { val container = executorIdToContainer.remove(executorId).get + containerIdToExecutorId.remove(container.getId) internalReleaseContainer(container) numExecutorsRunning -= 1 } else { @@ -383,6 +391,7 @@ private[yarn] class YarnAllocator( logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) executorIdToContainer(executorId) = container + containerIdToExecutorId(container.getId) = executorId val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, new HashSet[ContainerId]) @@ -413,12 +422,8 @@ private[yarn] class YarnAllocator( private[yarn] def processCompletedContainers(completedContainers: Seq[ContainerStatus]): Unit = { for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - - if (releasedContainers.contains(containerId)) { - // Already marked the container for release, so remove it from - // `releasedContainers`. - releasedContainers.remove(containerId) - } else { + val alreadyReleased = releasedContainers.remove(containerId) + if (!alreadyReleased) { // Decrement the number of executors running. The next iteration of // the ApplicationMaster's reporting thread will take care of allocating. 
numExecutorsRunning -= 1 @@ -460,6 +465,18 @@ private[yarn] class YarnAllocator( allocatedContainerToHostMap.remove(containerId) } + + containerIdToExecutorId.remove(containerId).foreach { eid => + executorIdToContainer.remove(eid) + + if (!alreadyReleased) { + // The executor could have gone away (like no route to host, node failure, etc) + // Notify backend about the failure of the executor + numUnexpectedContainerRelease += 1 + driverRef.send(RemoveExecutor(eid, + s"Yarn deallocated the executor $eid (container $containerId)")) + } + } } } @@ -467,6 +484,9 @@ private[yarn] class YarnAllocator( releasedContainers.add(container.getId()) amClient.releaseAssignedContainer(container.getId()) } + + private[yarn] def getNumUnexpectedContainerRelease = numUnexpectedContainerRelease + } private object YarnAllocator { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 7f533ee55e8bb..4999f9c06210a 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -56,6 +57,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg */ def register( driverUrl: String, + driverRef: RpcEndpointRef, conf: YarnConfiguration, sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], @@ -73,7 +75,8 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(driverUrl, conf, sparkConf, amClient, getAttemptId(), args, securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, + securityMgr) } /** diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 37a789fcd375b..58318bf9bcc08 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -27,10 +27,14 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.mockito.Mockito._ + import org.apache.spark.{SecurityManager, SparkFunSuite} import org.apache.spark.SparkConf import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.YarnAllocator._ +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.SplitInfo class MockResolver extends DNSToSwitchMapping { @@ -90,6 +94,7 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter "--class", "SomeClass") new YarnAllocator( "not used", + mock(classOf[RpcEndpointRef]), conf, sparkConf, rmClient, @@ -230,6 +235,30 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter handler.getNumPendingAllocate should be (1) } + test("lost executor removed from backend") { + val handler = createAllocator(4) 
+ handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (4) + + val container1 = createContainer("host1") + val container2 = createContainer("host2") + handler.handleAllocatedContainers(Array(container1, container2)) + + handler.requestTotalExecutorsWithPreferredLocalities(2, 0, Map()) + + val statuses = Seq(container1, container2).map { c => + ContainerStatus.newInstance(c.getId(), ContainerState.COMPLETE, "Failed", -1) + } + handler.updateResourceRequests() + handler.processCompletedContainers(statuses.toSeq) + handler.updateResourceRequests() + handler.getNumExecutorsRunning should be (0) + handler.getNumPendingAllocate should be (2) + handler.getNumExecutorsFailed should be (2) + handler.getNumUnexpectedContainerRelease should be (2) + } + test("memory exceeded diagnostic regexes") { val diagnostics = "Container [pid=12465,containerID=container_1412887393566_0003_01_000002] is running " + From ab78b1d2a6ce26833ea3878a63921efd805a3737 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 30 Jul 2015 10:40:04 -0700 Subject: [PATCH 169/219] [SPARK-9388] [YARN] Make executor info log messages easier to read. Author: Marcelo Vanzin Closes #7706 from vanzin/SPARK-9388 and squashes the following commits: 028b990 [Marcelo Vanzin] Single log statement. 3c5fb6a [Marcelo Vanzin] YARN not Yarn. 5bcd7a0 [Marcelo Vanzin] [SPARK-9388] [yarn] Make executor info log messages easier to read. --- .../scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- .../apache/spark/deploy/yarn/ExecutorRunnable.scala | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index bc28ce5eeae72..4ac3397f1ad28 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -767,7 +767,7 @@ private[spark] class Client( amContainer.setCommands(printableCommands) logDebug("===============================================================================") - logDebug("Yarn AM launch context:") + logDebug("YARN AM launch context:") logDebug(s" user class: ${Option(args.userClass).getOrElse("N/A")}") logDebug(" env:") launchEnv.foreach { case (k, v) => logDebug(s" $k -> $v") } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 78e27fb7f3337..52580deb372c2 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -86,10 +86,17 @@ class ExecutorRunnable( val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, appId, localResources) - logInfo(s"Setting up executor with environment: $env") - logInfo("Setting up executor with commands: " + commands) - ctx.setCommands(commands) + logInfo(s""" + |=============================================================================== + |YARN executor launch context: + | env: + |${env.map { case (k, v) => s" $k -> $v\n" }.mkString} + | command: + | ${commands.mkString(" ")} + |=============================================================================== + """.stripMargin) + ctx.setCommands(commands) ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) // If external shuffle service is enabled, register with 
the Yarn shuffle service already From 520ec0ff9db75267f627dc4615b2316a1a3d44d7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 30 Jul 2015 10:45:32 -0700 Subject: [PATCH 170/219] [SPARK-8850] [SQL] Enable Unsafe mode by default This pull request enables Unsafe mode by default in Spark SQL. In order to do this, we had to fix a number of small issues: **List of fixed blockers**: - [x] Make some default buffer sizes configurable so that HiveCompatibilitySuite can run properly (#7741). - [x] Memory leak on grouped aggregation of empty input (fixed by #7560 to fix this) - [x] Update planner to also check whether codegen is enabled before planning unsafe operators. - [x] Investigate failing HiveThriftBinaryServerSuite test. This turns out to be caused by a ClassCastException that occurs when Exchange tries to apply an interpreted RowOrdering to an UnsafeRow when range partitioning an RDD. This could be fixed by #7408, but a shorter-term fix is to just skip the Unsafe exchange path when RangePartitioner is used. - [x] Memory leak exceptions masking exceptions that actually caused tasks to fail (will be fixed by #7603). - [x] ~~https://issues.apache.org/jira/browse/SPARK-9162, to implement code generation for ScalaUDF. This is necessary for `UDFSuite` to pass. For now, I've just ignored this test in order to try to find other problems while we wait for a fix.~~ This is no longer necessary as of #7682. - [x] Memory leaks from Limit after UnsafeExternalSort cause the memory leak detector to fail tests. This is a huge problem in the HiveCompatibilitySuite (fixed by f4ac642a4e5b2a7931c5e04e086bb10e263b1db6). - [x] Tests in `AggregationQuerySuite` are failing due to NaN-handling issues in UnsafeRow, which were fixed in #7736. - [x] `org.apache.spark.sql.ColumnExpressionSuite.rand` needs to be updated so that the planner check also matches `TungstenProject`. - [x] After having lowered the buffer sizes to 4MB so that most of HiveCompatibilitySuite runs: - [x] Wrong answer in `join_1to1` (fixed by #7680) - [x] Wrong answer in `join_nulls` (fixed by #7680) - [x] Managed memory OOM / leak in `lateral_view` - [x] Seems to hang indefinitely in `partcols1`. This might be a deadlock in script transformation or a bug in error-handling code? The hang was fixed by #7710. - [x] Error while freeing memory in `partcols1`: will be fixed by #7734. - [x] After fixing the `partcols1` hang, it appears that a number of later tests have issues as well. - [x] Fix thread-safety bug in codegen fallback expression evaluation (#7759). 
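The limit-after-sort leak in the list above boils down to an iterator that is never
drained: Limit(10) stops pulling from the sorter after ten rows, so cleanup code that
runs after the last element is never reached. The fix registers a task-completion
callback (`taskContext.addOnCompleteCallback` in `UnsafeExternalSorter`, visible in
the diff below) so memory is freed no matter how the iterator is abandoned. A minimal
sketch of the idea, with simplified stand-ins for Spark's TaskContext and sorter:

    final class TaskContext {
      private var callbacks = List.empty[() => Unit]
      def addOnComplete(f: () => Unit): Unit = callbacks ::= f
      def complete(): Unit = callbacks.foreach(_.apply()) // runs at end of task
    }

    final class Sorter(ctx: TaskContext) {
      private var pages = List.fill(4)(new Array[Byte](1 << 20)) // "acquired" memory
      ctx.addOnComplete(() => free()) // cleanup no longer depends on the caller
      def iterator: Iterator[Int] = Iterator.range(0, 1000000)
      def free(): Unit = pages = Nil
    }

    val ctx = new TaskContext
    val firstTen = new Sorter(ctx).iterator.take(10).toList // iterator abandoned early
    ctx.complete() // the pages are still released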
Author: Josh Rosen Closes #7564 from JoshRosen/unsafe-by-default and squashes the following commits: 83c0c56 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-by-default f4cc859 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-by-default 963f567 [Josh Rosen] Reduce buffer size for R tests d6986de [Josh Rosen] Lower page size in PySpark tests 013b9da [Josh Rosen] Also match TungstenProject in checkNumProjects 5d0b2d3 [Josh Rosen] Add task completion callback to avoid leak in limit after sort ea250da [Josh Rosen] Disable unsafe Exchange path when RangePartitioning is used 715517b [Josh Rosen] Enable Unsafe by default --- R/run-tests.sh | 2 +- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++++++++++ python/pyspark/java_gateway.py | 6 +++++- .../scala/org/apache/spark/sql/SQLConf.scala | 2 +- .../apache/spark/sql/execution/Exchange.scala | 7 ++++++- .../spark/sql/ColumnExpressionSuite.scala | 3 ++- .../execution/UnsafeExternalSortSuite.scala | 20 +------------------ 7 files changed, 30 insertions(+), 24 deletions(-) diff --git a/R/run-tests.sh b/R/run-tests.sh index e82ad0ba2cd06..18a1e13bdc655 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 $FWDIR/../bin/sparkR --conf spark.buffer.pageSize=4m --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) if [[ $FAILED != 0 ]]; then diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index c21990f4e4778..866e0b4151577 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.LinkedList; +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,6 +93,17 @@ public UnsafeExternalSorter( this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); initializeForWriting(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). 
+ taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + freeMemory(); + return null; + } + }); } // TODO: metrics tracking + integration with shuffle write metrics diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 90cd342a6cf7f..60be85e53e2aa 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -52,7 +52,11 @@ def launch_gateway(): script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") if os.environ.get("SPARK_TESTING"): - submit_args = "--conf spark.ui.enabled=false " + submit_args + submit_args = ' '.join([ + "--conf spark.ui.enabled=false", + "--conf spark.buffer.pageSize=4mb", + submit_args + ]) command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2564bbd2077bf..6644e85d4a037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -229,7 +229,7 @@ private[spark] object SQLConf { " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..70e5031fb63c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = { + // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to + // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to + // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. + !newPartitioning.isInstanceOf[RangePartitioning] + } /** * Determines whether records must be defensively copied before being sent to the shuffle. 
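The RangePartitioning carve-out above exists because an ordering built for one row
representation fails at runtime when handed another. An illustrative sketch of the
failure mode, with made-up row classes rather than Spark's RowOrdering and UnsafeRow:

    trait Row
    final class GenericRow(val values: Array[Any]) extends Row
    final class CompactRow(val bytes: Array[Byte]) extends Row // UnsafeRow stand-in

    // An "interpreted" ordering that assumes every Row is a GenericRow.
    val interpreted = new Ordering[Row] {
      def compare(a: Row, b: Row): Int =
        a.asInstanceOf[GenericRow].values.length -
          b.asInstanceOf[GenericRow].values.length
    }

    // Range partitioning samples and compares rows; feeding CompactRows throws:
    // interpreted.compare(new CompactRow(Array[Byte](1)), new CompactRow(Array[Byte](2)))
    //   => java.lang.ClassCastException: CompactRow cannot be cast to GenericRow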
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 5c1102410879a..eb64684ae0fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.sql.test.SQLTestUtils @@ -538,6 +538,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { case project: Project => project + case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 7a4baa9e4a49d..138636b0c65b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -36,10 +36,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - ignore("sort followed by limit should not leak memory") { - // TODO: this test is going to fail until we implement a proper iterator interface - // with a close() method. 
-    TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
+  test("sort followed by limit") {
     checkThatPlansAgree(
       (1 to 100).map(v => Tuple1(v)).toDF("a"),
       (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
       (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
       sortAnswers = false
     )
   }

-  test("sort followed by limit") {
-    TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
-    try {
-      checkThatPlansAgree(
-        (1 to 100).map(v => Tuple1(v)).toDF("a"),
-        (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
-        (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
-        sortAnswers = false
-      )
-    } finally {
-      TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
-
-    }
-  }
-
   test("sorting does not crash for large inputs") {
     val sortOrder = 'a.asc :: Nil
     val stringLength = 1024 * 1024 * 2

From 06b6a074fb224b3fe23922bdc89fc5f7c2ffaaf6 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Thu, 30 Jul 2015 10:46:26 -0700
Subject: [PATCH 171/219] [SPARK-9437] [CORE] avoid overflow in SizeEstimator

https://issues.apache.org/jira/browse/SPARK-9437

Author: Imran Rashid

Closes #7750 from squito/SPARK-9437_size_estimator_overflow and squashes the following commits:

29493f1 [Imran Rashid] prevent another potential overflow
bc1cb82 [Imran Rashid] avoid overflow
---
 .../main/scala/org/apache/spark/util/SizeEstimator.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 7d84468f62ab1..14b1f2a17e707 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -217,10 +217,10 @@ object SizeEstimator extends Logging {
     var arrSize: Long = alignSize(objectSize + INT_SIZE)

     if (elementClass.isPrimitive) {
-      arrSize += alignSize(length * primitiveSize(elementClass))
+      arrSize += alignSize(length.toLong * primitiveSize(elementClass))
       state.size += arrSize
     } else {
-      arrSize += alignSize(length * pointerSize)
+      arrSize += alignSize(length.toLong * pointerSize)
       state.size += arrSize

       if (length <= ARRAY_SIZE_FOR_SAMPLING) {
@@ -336,7 +336,7 @@ object SizeEstimator extends Logging {
     // hg.openjdk.java.net/jdk8/jdk8/hotspot/file/tip/src/share/vm/classfile/classFileParser.cpp
     var alignedSize = shellSize
     for (size <- fieldSizes if sizeCount(size) > 0) {
-      val count = sizeCount(size)
+      val count = sizeCount(size).toLong
       // If there are internal gaps, smaller field can fit in.
       alignedSize = math.max(alignedSize, alignSizeUp(shellSize, size) + size * count)
       shellSize += size * count

From 6d94bf6ac10ac851636c62439f8f2737f3526a2a Mon Sep 17 00:00:00 2001
From: Daoyuan Wang
Date: Thu, 30 Jul 2015 11:13:15 -0700
Subject: [PATCH 172/219] [SPARK-8174] [SPARK-8175] [SQL] function
 unix_timestamp, from_unixtime

unix_timestamp(): long
Gets the current Unix timestamp in seconds.

unix_timestamp(string|date): long
Converts a time string in the format yyyy-MM-dd HH:mm:ss to a Unix timestamp (in
seconds), using the default timezone and the default locale; returns null on failure:
unix_timestamp('2009-03-20 11:30:01') = 1237573801

unix_timestamp(string date, string pattern): long
Converts a time string with the given pattern (see
[http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) to a Unix
timestamp (in seconds); returns null on failure:
unix_timestamp('2009-03-20', 'yyyy-MM-dd') = 1237532400

from_unixtime(bigint unixtime[, string format]): string
Converts the number of seconds from the Unix epoch (1970-01-01 00:00:00 UTC) to a
string representing the timestamp of that moment in the current system time zone,
in a format such as "1970-01-01 00:00:00".
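The same functions are exposed both to SQL (via the FunctionRegistry entries below)
and to the DataFrame API in `org.apache.spark.sql.functions`. A quick usage sketch,
assuming a DataFrame `df` with a string column `ts` (the data is hypothetical; the
function names and signatures match the ones this patch adds):

    import org.apache.spark.sql.functions.{from_unixtime, unix_timestamp}

    // String -> epoch seconds; rows that fail to parse become null.
    val withEpoch = df.select(unix_timestamp(df("ts"), "yyyy-MM-dd HH:mm:ss").as("epoch"))

    // Epoch seconds -> formatted string in the current system time zone.
    val roundTrip = withEpoch.select(from_unixtime(withEpoch("epoch"), "yyyy-MM-dd HH:mm:ss"))

    // Equivalent SQL, using the registrations added below:
    sqlContext.sql("SELECT unix_timestamp('2009-03-20 11:30:01'), from_unixtime(1237573801)")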
Jira:
https://issues.apache.org/jira/browse/SPARK-8174
https://issues.apache.org/jira/browse/SPARK-8175

Author: Daoyuan Wang

Closes #7644 from adrian-wang/udfunixtime and squashes the following commits:

2fe20c4 [Daoyuan Wang] util.Date
ea2ec16 [Daoyuan Wang] use util.Date for better performance
a2cf929 [Daoyuan Wang] doc return null instead of 0
f6f070a [Daoyuan Wang] address comments from davies
6a4cbb3 [Daoyuan Wang] temp
56ded53 [Daoyuan Wang] rebase and address comments
14a8b37 [Daoyuan Wang] function unix_timestamp, from_unixtime
---
 .../catalyst/analysis/FunctionRegistry.scala  |   2 +
 .../expressions/datetimeFunctions.scala       | 219 +++++++++++++++++-
 .../expressions/DateExpressionsSuite.scala    |  59 ++++-
 .../org/apache/spark/sql/functions.scala      |  42 ++++
 .../apache/spark/sql/DateFunctionsSuite.scala |  56 +++++
 5 files changed, 374 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 378df4f57d9e2..d663f12bc6d0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -211,6 +211,7 @@ object FunctionRegistry {
     expression[DayOfMonth]("day"),
     expression[DayOfYear]("dayofyear"),
     expression[DayOfMonth]("dayofmonth"),
+    expression[FromUnixTime]("from_unixtime"),
     expression[Hour]("hour"),
     expression[LastDay]("last_day"),
     expression[Minute]("minute"),
@@ -218,6 +219,7 @@ object FunctionRegistry {
     expression[NextDay]("next_day"),
     expression[Quarter]("quarter"),
     expression[Second]("second"),
+    expression[UnixTimestamp]("unix_timestamp"),
     expression[WeekOfYear]("weekofyear"),
     expression[Year]("year"),

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
index efecb771f2f5d..a5e6249e438d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
@@ -17,7 +17,6 @@

 package org.apache.spark.sql.catalyst.expressions

-import java.sql.Date
 import java.text.SimpleDateFormat
 import java.util.{Calendar, TimeZone}

@@ -28,6 +27,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

+import scala.util.Try
+
 /**
  * Returns the current date at the start of query evaluation.
  * All calls of current_date within the same query return the same value.
@@ -236,20 +237,232 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx

   override protected def nullSafeEval(timestamp: Any, format: Any): Any = {
     val sdf = new SimpleDateFormat(format.toString)
-    UTF8String.fromString(sdf.format(new Date(timestamp.asInstanceOf[Long] / 1000)))
+    UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000)))
   }

   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val sdf = classOf[SimpleDateFormat].getName
     defineCodeGen(ctx, ev, (timestamp, format) => {
       s"""UTF8String.fromString((new $sdf($format.toString()))
-        .format(new java.sql.Date($timestamp / 1000)))"""
+        .format(new java.util.Date($timestamp / 1000)))"""
     })
   }

   override def prettyName: String = "date_format"
 }

+/**
+ * Converts a time string with the given pattern
+ * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html])
+ * to a Unix timestamp (in seconds); returns null on failure.
+ * Note that the Hive Language Manual says it returns 0 on failure, but in fact it returns null.
+ * If the second parameter is missing, the default format "yyyy-MM-dd HH:mm:ss" is used.
+ * If no parameters are provided, the first parameter defaults to current_timestamp.
+ * If the first parameter is a Date or Timestamp instead of a String, the
+ * second parameter is ignored.
+ */
+case class UnixTimestamp(timeExp: Expression, format: Expression)
+  extends BinaryExpression with ExpectsInputTypes {
+
+  override def left: Expression = timeExp
+  override def right: Expression = format
+
+  def this(time: Expression) = {
+    this(time, Literal("yyyy-MM-dd HH:mm:ss"))
+  }
+
+  def this() = {
+    this(CurrentTimestamp())
+  }
+
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(StringType, DateType, TimestampType), StringType)
+
+  override def dataType: DataType = LongType
+
+  private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
+
+  override def eval(input: InternalRow): Any = {
+    val t = left.eval(input)
+    if (t == null) {
+      null
+    } else {
+      left.dataType match {
+        case DateType =>
+          DateTimeUtils.daysToMillis(t.asInstanceOf[Int]) / 1000L
+        case TimestampType =>
+          t.asInstanceOf[Long] / 1000000L
+        case StringType if right.foldable =>
+          if (constFormat != null) {
+            Try(new SimpleDateFormat(constFormat.toString).parse(
+              t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
+          } else {
+            null
+          }
+        case StringType =>
+          val f = format.eval(input)
+          if (f == null) {
+            null
+          } else {
+            val formatString = f.asInstanceOf[UTF8String].toString
+            Try(new SimpleDateFormat(formatString).parse(
+              t.asInstanceOf[UTF8String].toString).getTime / 1000L).getOrElse(null)
+          }
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    left.dataType match {
+      case StringType if right.foldable =>
+        val sdf = classOf[SimpleDateFormat].getName
+        val fString = if (constFormat == null) null else constFormat.toString
+        val formatter = ctx.freshName("formatter")
+        if (fString == null) {
+          s"""
+            boolean ${ev.isNull} = true;
+            ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          """
+        } else {
+          val eval1 = left.gen(ctx)
+          s"""
+            ${eval1.code}
+            boolean ${ev.isNull} = ${eval1.isNull};
+            ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+            if (!${ev.isNull}) {
+              try {
+                $sdf $formatter = new $sdf("$fString");
+                ${ev.primitive} =
+                  $formatter.parse(${eval1.primitive}.toString()).getTime() / 1000L;
+              } catch (java.lang.Throwable e) {
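+                // A parse failure must not throw at runtime; setting isNull mirrors
+                // the Try(...).getOrElse(null) behavior of the interpreted eval above.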
+                ${ev.isNull} = true;
+              }
+            }
+          """
+        }
+      case StringType =>
+        val sdf = classOf[SimpleDateFormat].getName
+        nullSafeCodeGen(ctx, ev, (string, format) => {
+          s"""
+            try {
+              ${ev.primitive} =
+                (new $sdf($format.toString())).parse($string.toString()).getTime() / 1000L;
+            } catch (java.lang.Throwable e) {
+              ${ev.isNull} = true;
+            }
+          """
+        })
+      case TimestampType =>
+        val eval1 = left.gen(ctx)
+        s"""
+          ${eval1.code}
+          boolean ${ev.isNull} = ${eval1.isNull};
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (!${ev.isNull}) {
+            ${ev.primitive} = ${eval1.primitive} / 1000000L;
+          }
+        """
+      case DateType =>
+        val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+        val eval1 = left.gen(ctx)
+        s"""
+          ${eval1.code}
+          boolean ${ev.isNull} = ${eval1.isNull};
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (!${ev.isNull}) {
+            ${ev.primitive} = $dtu.daysToMillis(${eval1.primitive}) / 1000L;
+          }
+        """
+    }
+  }
+}
+
+/**
+ * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+ * representing the timestamp of that moment in the current system time zone in the given
+ * format. If the format is missing, the default format "yyyy-MM-dd HH:mm:ss" is used,
+ * producing strings like "1970-01-01 00:00:00".
+ * Note that the Hive Language Manual says it returns 0 on failure, but in fact it returns null.
+ */
+case class FromUnixTime(sec: Expression, format: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = sec
+  override def right: Expression = format
+
+  def this(unix: Expression) = {
+    this(unix, Literal("yyyy-MM-dd HH:mm:ss"))
+  }
+
+  override def dataType: DataType = StringType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType)
+
+  private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String]
+
+  override def eval(input: InternalRow): Any = {
+    val time = left.eval(input)
+    if (time == null) {
+      null
+    } else {
+      if (format.foldable) {
+        if (constFormat == null) {
+          null
+        } else {
+          Try(UTF8String.fromString(new SimpleDateFormat(constFormat.toString).format(
+            new java.util.Date(time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
+        }
+      } else {
+        val f = format.eval(input)
+        if (f == null) {
+          null
+        } else {
+          Try(UTF8String.fromString(new SimpleDateFormat(
+            f.asInstanceOf[UTF8String].toString).format(new java.util.Date(
+              time.asInstanceOf[Long] * 1000L)))).getOrElse(null)
+        }
+      }
+    }
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val sdf = classOf[SimpleDateFormat].getName
+    if (format.foldable) {
+      if (constFormat == null) {
+        s"""
+          boolean ${ev.isNull} = true;
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+        """
+      } else {
+        val t = left.gen(ctx)
+        s"""
+          ${t.code}
+          boolean ${ev.isNull} = ${t.isNull};
+          ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+          if (!${ev.isNull}) {
+            try {
+              ${ev.primitive} = UTF8String.fromString(new $sdf("${constFormat.toString}").format(
+                new java.util.Date(${t.primitive} * 1000L)));
+            } catch (java.lang.Throwable e) {
+              ${ev.isNull} = true;
+            }
+          }
+        """
+      }
+    } else {
+      nullSafeCodeGen(ctx, ev, (seconds, f) => {
+        s"""
+          try {
+            ${ev.primitive} = UTF8String.fromString((new $sdf($f.toString())).format(
+              new java.util.Date($seconds * 1000L)));
+          } catch (java.lang.Throwable e) {
+            ${ev.isNull} = true;
+          }""".stripMargin
+      })
+    }
+  }
+
+}
+
 /**
  * Returns the last day of the month to which the date belongs.
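  * For example, the last day of the month for 2015-03-27 is 2015-03-31
  * (see the last_day tests below).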
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index aca8d6eb3500c..e1387f945ffa4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,8 +22,9 @@ import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{StringType, TimestampType, DateType} +import org.apache.spark.sql.types._ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -303,4 +304,60 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null) } + + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss")), sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2)), sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(FromUnixTime(Literal(1000L), Literal.create(null, StringType)), null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format")), null) + } + + test("unix_timestamp") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3) + val date1 = Date.valueOf("2015-07-24") + checkEvaluation( + UnixTimestamp(Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss")), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss")), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1)) / 1000L) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2)), -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3)), + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24"))) / 1000L) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal.create(null, StringType)), null) + checkEvaluation( + UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss")), null) + checkEvaluation(UnixTimestamp( + Literal(date1), Literal.create(null, 
StringType)), date1.getTime / 1000L)
+    checkEvaluation(
+      UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index a2fece62f61f9..3f440e062eb96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2110,6 +2110,48 @@ object functions {
    */
   def weekofyear(columnName: String): Column = weekofyear(Column(columnName))

+  /**
+   * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+   * representing the timestamp of that moment in the current system time zone, using the
+   * default format yyyy-MM-dd HH:mm:ss.
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+
+  /**
+   * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string
+   * representing the timestamp of that moment in the current system time zone in the given
+   * format.
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f))
+
+  /**
+   * Gets the current Unix timestamp in seconds.
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss"))
+
+  /**
+   * Converts a time string in the format yyyy-MM-dd HH:mm:ss to a Unix timestamp (in seconds),
+   * using the default timezone and the default locale; returns null on failure.
+   * @group datetime_funcs
+   * @since 1.5.0
+   */
+  def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss"))
+
+  /**
+   * Converts a time string with the given pattern
+   * (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html])
+   * to a Unix timestamp (in seconds); returns null on failure.
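+   * For example, unix_timestamp('2009-03-20', 'yyyy-MM-dd') = 1237532400, per the
+   * example in the commit message above (the result is timezone-dependent).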
+ * @group datetime_funcs + * @since 1.5.0 + */ + def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 07eb6e4a8d8cd..df4cb57ac5b21 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -228,4 +228,60 @@ class DateFunctionsSuite extends QueryTest { Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) } + test("from_unixtime") { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + + test("unix_timestamp") { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(date1.getTime / 1000L), Row(date2.getTime / 1000L))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(ts1.getTime / 1000L), Row(ts2.getTime / 1000L))) + } + } From a20e743fb863de809863652931bc982aac2d1f86 Mon Sep 17 00:00:00 
2001
From: Reynold Xin
Date: Thu, 30 Jul 2015 13:09:43 -0700
Subject: [PATCH 173/219] [SPARK-9460] Fix prefix generation for UTF8String.

Previously we could get garbage data if the number of bytes was 0, on JVMs
that are 4-byte aligned, or when CompressedOops is on.

Author: Reynold Xin

Closes #7789 from rxin/utf8string and squashes the following commits:

86ffa3e [Reynold Xin] Mask out data outside of valid range.
4d647ed [Reynold Xin] Mask out data.
c6e8794 [Reynold Xin] [SPARK-9460] Fix prefix generation for UTF8String.
---
 .../apache/spark/unsafe/types/UTF8String.java | 36 +++++++++++++++++--
 .../spark/unsafe/types/UTF8StringSuite.java | 8 +++++
 2 files changed, 41 insertions(+), 3 deletions(-)

diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 57522003ba2ba..c38953f65d7d7 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -65,6 +65,19 @@ public static UTF8String fromBytes(byte[] bytes) {
     }
   }

+  /**
+   * Creates a UTF8String from a byte array, which should be encoded in UTF-8.
+   *
+   * Note: `bytes` will be held by the returned UTF8String.
+   */
+  public static UTF8String fromBytes(byte[] bytes, int offset, int numBytes) {
+    if (bytes != null) {
+      return new UTF8String(bytes, BYTE_ARRAY_OFFSET + offset, numBytes);
+    } else {
+      return null;
+    }
+  }
+
   /**
    * Creates an UTF8String from String.
    */
@@ -89,10 +102,10 @@ public static UTF8String blankString(int length) {
     return fromBytes(spaces);
   }

-  protected UTF8String(Object base, long offset, int size) {
+  protected UTF8String(Object base, long offset, int numBytes) {
     this.base = base;
     this.offset = offset;
-    this.numBytes = size;
+    this.numBytes = numBytes;
   }

   /**
@@ -141,7 +154,24 @@ public int numChars() {
    * Returns a 64-bit integer that can be used as the prefix used in sorting.
    */
   public long getPrefix() {
-    long p = PlatformDependent.UNSAFE.getLong(base, offset);
+    // Since JVMs are either 4-byte aligned or 8-byte aligned, we check the size of the string.
+    // If size is 0, just return 0.
+    // If size is between 0 and 4 (inclusive), assume data is 4-byte aligned under the hood and
+    // use a getInt to fetch the prefix.
+    // If size is greater than 4, assume we have at least 8 bytes of data to fetch.
+    // After getting the data, we use a mask to mask out data that is not part of the string.
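+    // For example, with numBytes == 3 the mask is (1L << 3 * 8) - 1 == 0xFFFFFFL, so only
+    // the three valid low-order bytes survive before the byte reversal below.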
+ long p; + if (numBytes >= 8) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + } else if (numBytes > 4) { + p = PlatformDependent.UNSAFE.getLong(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else if (numBytes > 0) { + p = (long) PlatformDependent.UNSAFE.getInt(base, offset); + p = p & ((1L << numBytes * 8) - 1); + } else { + p = 0; + } p = java.lang.Long.reverseBytes(p); return p; } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 42e09e435a412..f2cc19ca6b172 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -71,6 +71,14 @@ public void prefix() { fromString("abbbbbbbbbbbasdf").getPrefix() - fromString("bbbbbbbbbbbbasdf").getPrefix() < 0); assertTrue(fromString("").getPrefix() - fromString("a").getPrefix() < 0); assertTrue(fromString("你好").getPrefix() - fromString("世界").getPrefix() > 0); + + byte[] buf1 = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + byte[] buf2 = {1, 2, 3}; + UTF8String str1 = UTF8String.fromBytes(buf1, 0, 3); + UTF8String str2 = UTF8String.fromBytes(buf1, 0, 8); + UTF8String str3 = UTF8String.fromBytes(buf2); + assertTrue(str1.getPrefix() - str2.getPrefix() < 0); + assertEquals(str1.getPrefix(), str3.getPrefix()); } @Test From d8cfd531c7c50c9b00ab546be458f44f84c386ae Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 30 Jul 2015 13:17:54 -0700 Subject: [PATCH 174/219] [SPARK-5567] [MLLIB] Add predict method to LocalLDAModel jkbradley hhbyyh Adds `topicDistributions` to LocalLDAModel. Please review after #7757 is merged. Author: Feynman Liang Closes #7760 from feynmanliang/SPARK-5567-predict-in-LDA and squashes the following commits: 0ad1134 [Feynman Liang] Remove println 27b3877 [Feynman Liang] Code review fixes 6bfb87c [Feynman Liang] Remove extra newline 476f788 [Feynman Liang] Fix checks and doc for variationalInference 061780c [Feynman Liang] Code review cleanup 3be2947 [Feynman Liang] Rename topicDistribution -> topicDistributions 2a821a6 [Feynman Liang] Add predict methods to LocalLDAModel --- .../spark/mllib/clustering/LDAModel.scala | 42 +++++++++++-- .../spark/mllib/clustering/LDAOptimizer.scala | 5 +- .../spark/mllib/clustering/LDASuite.scala | 63 +++++++++++++++++++ 3 files changed, 102 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index ece28848aa02c..6cfad3fbbdb87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -186,7 +186,6 @@ abstract class LDAModel private[clustering] extends Saveable { * This model stores only the inferred topics. * It may be used for computing topics for new documents, but it may give less accurate answers * than the [[DistributedLDAModel]]. - * * @param topics Inferred topics (vocabSize x k matrix). */ @Experimental @@ -221,9 +220,6 @@ class LocalLDAModel private[clustering] ( // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? - // TODO: - // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? - /** * Calculate the log variational bound on perplexity. See Equation (16) in original Online * LDA paper. 
@@ -269,7 +265,7 @@ class LocalLDAModel private[clustering] ( // by topic (columns of lambda) val Elogbeta = LDAUtils.dirichletExpectation(lambda.t).t - var score = documents.filter(_._2.numActives > 0).map { case (id: Long, termCounts: Vector) => + var score = documents.filter(_._2.numNonzeros > 0).map { case (id: Long, termCounts: Vector) => var docScore = 0.0D val (gammad: BDV[Double], _) = OnlineLDAOptimizer.variationalTopicInference( termCounts, exp(Elogbeta), brzAlpha, gammaShape, k) @@ -277,7 +273,7 @@ class LocalLDAModel private[clustering] ( // E[log p(doc | theta, beta)] termCounts.foreachActive { case (idx, count) => - docScore += LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) + docScore += count * LDAUtils.logSumExp(Elogthetad + Elogbeta(idx, ::).t) } // E[log p(theta | alpha) - log q(theta | gamma)]; assumes alpha is a vector docScore += sum((brzAlpha - gammad) :* Elogthetad) @@ -297,6 +293,40 @@ class LocalLDAModel private[clustering] ( score } + /** + * Predicts the topic mixture distribution for each document (often called "theta" in the + * literature). Returns a vector of zeros for an empty document. + * + * This uses a variational approximation following Hoffman et al. (2010), where the approximate + * distribution is called "gamma." Technically, this method returns this approximation "gamma" + * for each document. + * @param documents documents to predict topic mixture distributions for + * @return An RDD of (document ID, topic mixture distribution for document) + */ + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = { + // Double transpose because dirichletExpectation normalizes by row and we need to normalize + // by topic (columns of lambda) + val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.toBreeze.toDenseMatrix.t).t) + val docConcentrationBrz = this.docConcentration.toBreeze + val gammaShape = this.gammaShape + val k = this.k + + documents.map { case (id: Long, termCounts: Vector) => + if (termCounts.numNonzeros == 0) { + (id, Vectors.zeros(k)) + } else { + val (gamma, _) = OnlineLDAOptimizer.variationalTopicInference( + termCounts, + expElogbeta, + docConcentrationBrz, + gammaShape, + k) + (id, Vectors.dense(normalize(gamma, 1.0).toArray)) + } + } + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 4b90fbdf0ce7e..9dbec41efeada 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -394,7 +394,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val gammaShape = this.gammaShape val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => - val nonEmptyDocs = docs.filter(_._2.numActives > 0) + val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) val stat = BDM.zeros[Double](k, vocabSize) var gammaPart = List[BDV[Double]]() @@ -461,7 +461,8 @@ final class OnlineLDAOptimizer extends LDAOptimizer { private[clustering] object OnlineLDAOptimizer { /** * Uses variational inference to infer the topic distribution `gammad` given the term counts - * for a document. `termCounts` must be non-empty, otherwise Breeze will throw a BLAS error. + * for a document. `termCounts` must contain at least one non-zero entry, otherwise Breeze will + * throw a BLAS error. 
* * An optimization (Lee, Seung: Algorithms for non-negative matrix factorization, NIPS 2001) * avoids explicit computation of variational parameter `phi`. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 61d2edfd9fb5f..d74482d3a7598 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -242,6 +242,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { val alpha = 0.01 val eta = 0.01 val gammaShape = 100 + // obtained from LDA model trained in gensim, see below val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) @@ -281,6 +282,68 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) } + test("LocalLDAModel predict") { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + // obtained from LDA model trained in gensim, see below + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + + def toydata: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + val docs = sc.parallelize(toydata) + + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + print(list(lda.get_document_topics(corpus))) + > [[(0, 0.99504950495049516)], [(0, 0.99504950495049516)], + > [(0, 0.99504950495049516)], [(1, 0.99504950495049516)], + > [(1, 0.99504950495049516)], [(1, 0.99504950495049516)]] + */ + + val expectedPredictions = List( + (0, 0.99504), (0, 0.99504), + (0, 0.99504), (1, 0.99504), + (1, 0.99504), (1, 0.99504)) + + val actualPredictions = ldaModel.topicDistributions(docs).map { case (id, topics) => + // convert results to expectedPredictions format, which only has highest probability topic + val topicsBz = topics.toBreeze.toDenseVector + (id, (argmax(topicsBz), max(topicsBz))) + }.sortByKey() + .values + .collect() + + expectedPredictions.zip(actualPredictions).forall { case (expected, actual) => + expected._1 === actual._1 && (expected._2 ~== actual._2 relTol 1E-3D) + } + } + test("OnlineLDAOptimizer with asymmetric prior") { def toydata: Array[(Long, Vector)] = Array( Vectors.sparse(6, Array(0, 1), Array(1, 1)), From 1abf7dc16ca1ba1777fe874c8b81fe6f2b0a6de5 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 13:21:46 
-0700 Subject: [PATCH 175/219] [SPARK-8186] [SPARK-8187] [SPARK-8194] [SPARK-8198] [SPARK-9133] [SPARK-9290] [SQL] functions: date_add, date_sub, add_months, months_between, time-interval calculation This PR is based on #7589 , thanks to adrian-wang Added SQL function date_add, date_sub, add_months, month_between, also add a rule for add/subtract of date/timestamp and interval. Closes #7589 cc rxin Author: Daoyuan Wang Author: Davies Liu Closes #7754 from davies/date_add and squashes the following commits: e8c633a [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add 9e8e085 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add 6224ce4 [Davies Liu] fix conclict bd18cd4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add e47ff2c [Davies Liu] add python api, fix date functions 01943d0 [Davies Liu] Merge branch 'master' into date_add 522e91a [Daoyuan Wang] fix e8a639a [Daoyuan Wang] fix 42df486 [Daoyuan Wang] fix style 87c4b77 [Daoyuan Wang] function add_months, months_between and some fixes 1a68e03 [Daoyuan Wang] poc of time interval calculation c506661 [Daoyuan Wang] function date_add , date_sub --- python/pyspark/sql/functions.py | 76 ++++++- .../catalyst/analysis/FunctionRegistry.scala | 4 + .../catalyst/analysis/HiveTypeCoercion.scala | 22 ++ .../expressions/datetimeFunctions.scala | 155 ++++++++++++- .../sql/catalyst/util/DateTimeUtils.scala | 139 ++++++++++++ .../analysis/HiveTypeCoercionSuite.scala | 30 +++ .../expressions/DateExpressionsSuite.scala | 176 +++++++++------ .../catalyst/util/DateTimeUtilsSuite.scala | 205 +++++++++++------- .../org/apache/spark/sql/functions.scala | 29 +++ .../apache/spark/sql/DateFunctionsSuite.scala | 117 ++++++++++ 10 files changed, 791 insertions(+), 162 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index d930f7db25d25..a7295e25f0aa5 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -59,7 +59,7 @@ __all__ += ['lag', 'lead', 'ntile'] __all__ += [ - 'date_format', + 'date_format', 'date_add', 'date_sub', 'add_months', 'months_between', 'year', 'quarter', 'month', 'hour', 'minute', 'second', 'dayofmonth', 'dayofyear', 'weekofyear'] @@ -716,7 +716,7 @@ def date_format(dateCol, format): [Row(date=u'04/08/2015')] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.date_format(dateCol, format)) + return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format)) @since(1.5) @@ -729,7 +729,7 @@ def year(col): [Row(year=2015)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.year(col)) + return Column(sc._jvm.functions.year(_to_java_column(col))) @since(1.5) @@ -742,7 +742,7 @@ def quarter(col): [Row(quarter=2)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.quarter(col)) + return Column(sc._jvm.functions.quarter(_to_java_column(col))) @since(1.5) @@ -755,7 +755,7 @@ def month(col): [Row(month=4)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.month(col)) + return Column(sc._jvm.functions.month(_to_java_column(col))) @since(1.5) @@ -768,7 +768,7 @@ def dayofmonth(col): [Row(day=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.dayofmonth(col)) + return Column(sc._jvm.functions.dayofmonth(_to_java_column(col))) @since(1.5) @@ -781,7 +781,7 @@ def dayofyear(col): [Row(day=98)] """ sc = SparkContext._active_spark_context - return 
Column(sc._jvm.functions.dayofyear(col)) + return Column(sc._jvm.functions.dayofyear(_to_java_column(col))) @since(1.5) @@ -794,7 +794,7 @@ def hour(col): [Row(hour=13)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.hour(col)) + return Column(sc._jvm.functions.hour(_to_java_column(col))) @since(1.5) @@ -807,7 +807,7 @@ def minute(col): [Row(minute=8)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.minute(col)) + return Column(sc._jvm.functions.minute(_to_java_column(col))) @since(1.5) @@ -820,7 +820,7 @@ def second(col): [Row(second=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.second(col)) + return Column(sc._jvm.functions.second(_to_java_column(col))) @since(1.5) @@ -829,11 +829,63 @@ def weekofyear(col): Extract the week number of a given date as integer. >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a']) - >>> df.select(weekofyear('a').alias('week')).collect() + >>> df.select(weekofyear(df.a).alias('week')).collect() [Row(week=15)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.weekofyear(col)) + return Column(sc._jvm.functions.weekofyear(_to_java_column(col))) + + +@since(1.5) +def date_add(start, days): + """ + Returns the date that is `days` days after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_add(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 9))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_add(_to_java_column(start), days)) + + +@since(1.5) +def date_sub(start, days): + """ + Returns the date that is `days` days before `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(date_sub(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 4, 7))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.date_sub(_to_java_column(start), days)) + + +@since(1.5) +def add_months(start, months): + """ + Returns the date that is `months` months after `start` + + >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d']) + >>> df.select(add_months(df.d, 1).alias('d')).collect() + [Row(d=datetime.date(2015, 5, 8))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.add_months(_to_java_column(start), months)) + + +@since(1.5) +def months_between(date1, date2): + """ + Returns the number of months between date1 and date2. 
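+    The difference is computed on a 31-days-per-month basis and rounded to 8 digits
+    when the two days of month differ (see DateTimeUtils.monthsBetween below).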
+
+    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
+    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
+    [Row(months=3.9495967...)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))


 @since(1.5)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d663f12bc6d0d..6c7c481fab8db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -205,9 +205,12 @@ object FunctionRegistry {
     expression[Upper]("upper"),

     // datetime functions
+    expression[AddMonths]("add_months"),
     expression[CurrentDate]("current_date"),
     expression[CurrentTimestamp]("current_timestamp"),
+    expression[DateAdd]("date_add"),
     expression[DateFormatClass]("date_format"),
+    expression[DateSub]("date_sub"),
     expression[DayOfMonth]("day"),
     expression[DayOfYear]("dayofyear"),
     expression[DayOfMonth]("dayofmonth"),
@@ -216,6 +219,7 @@
     expression[LastDay]("last_day"),
     expression[Minute]("minute"),
     expression[Month]("month"),
+    expression[MonthsBetween]("months_between"),
     expression[NextDay]("next_day"),
     expression[Quarter]("quarter"),
     expression[Second]("second"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index ecc48986e35d8..603afc4032a37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -47,6 +47,7 @@
       Division ::
       PropagateTypes ::
       ImplicitTypeCasts ::
+      DateTimeOperations ::
       Nil

   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -638,6 +639,27 @@
     }
   }

+  /**
+   * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
+   * into TimeAdd/TimeSub.
+   */
+  object DateTimeOperations extends Rule[LogicalPlan] {
+
+    private val acceptedTypes = Seq(DateType, TimestampType, StringType)
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Skip nodes whose children have not been resolved yet.
+      case e if !e.childrenResolved => e
+
+      case Add(l @ CalendarIntervalType(), r) if acceptedTypes.contains(r.dataType) =>
+        Cast(TimeAdd(r, l), r.dataType)
+      case Add(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
+        Cast(TimeAdd(l, r), l.dataType)
+      case Subtract(l, r @ CalendarIntervalType()) if acceptedTypes.contains(l.dataType) =>
+        Cast(TimeSub(l, r), l.dataType)
+    }
+  }
+
   /**
    * Casts types according to the expected input types for [[Expression]]s.
   */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
index a5e6249e438d2..9795673ee0664 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

 import scala.util.Try

@@ -63,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
   }
 }

+/**
+ * Adds a number of days to startdate.
+ */
+case class DateAdd(startDate: Expression, days: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = startDate
+  override def right: Expression = days
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, d: Any): Any = {
+    start.asInstanceOf[Int] + d.asInstanceOf[Int]
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (sd, d) => {
+      s"""${ev.primitive} = $sd + $d;"""
+    })
+  }
+}
+
+/**
+ * Subtracts a number of days from startdate.
+ */
+case class DateSub(startDate: Expression, days: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+  override def left: Expression = startDate
+  override def right: Expression = days
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)
+
+  override def dataType: DataType = DateType
+
+  override def nullSafeEval(start: Any, d: Any): Any = {
+    start.asInstanceOf[Int] - d.asInstanceOf[Int]
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, (sd, d) => {
+      s"""${ev.primitive} = $sd - $d;"""
+    })
+  }
+}
+
 case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

   override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -543,3 +590,109 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)

   override def prettyName: String = "next_day"
 }
+
+/**
+ * Adds an interval to a timestamp.
+ */
+case class TimeAdd(start: Expression, interval: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def left: Expression = start
+  override def right: Expression = interval
+
+  override def toString: String = s"$left + $right"
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)
+
+  override def dataType: DataType = TimestampType
+
+  override def nullSafeEval(start: Any, interval: Any): Any = {
+    val itvl = interval.asInstanceOf[CalendarInterval]
+    DateTimeUtils.timestampAddInterval(
+      start.asInstanceOf[Long], itvl.months, itvl.microseconds)
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    defineCodeGen(ctx, ev, (sd, i) => {
+      s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
+    })
+  }
+}
+
+/**
+ * Subtracts an interval from a timestamp.
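+ * For example, subtracting an interval of 1 month from timestamp
+ * 2016-03-31 10:00:00 yields 2016-02-29 10:00:00 (see the time_sub test below).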
+ */ +case class TimeSub(start: Expression, interval: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = start + override def right: Expression = interval + + override def toString: String = s"$left - $right" + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType) + + override def dataType: DataType = TimestampType + + override def nullSafeEval(start: Any, interval: Any): Any = { + val itvl = interval.asInstanceOf[CalendarInterval] + DateTimeUtils.timestampAddInterval( + start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, i) => { + s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" + }) + } +} + +/** + * Returns the date that is num_months after start_date. + */ +case class AddMonths(startDate: Expression, numMonths: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = startDate + override def right: Expression = numMonths + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType) + + override def dataType: DataType = DateType + + override def nullSafeEval(start: Any, months: Any): Any = { + DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (sd, m) => { + s"""$dtu.dateAddMonths($sd, $m)""" + }) + } +} + +/** + * Returns number of months between dates date1 and date2. + */ +case class MonthsBetween(date1: Expression, date2: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = date1 + override def right: Expression = date2 + + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType) + + override def dataType: DataType = DoubleType + + override def nullSafeEval(t1: Any, t2: Any): Any = { + DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + defineCodeGen(ctx, ev, (l, r) => { + s"""$dtu.monthsBetween($l, $r)""" + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 93966a503c27c..53abdf6618eac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -45,6 +45,7 @@ object DateTimeUtils { final val to2001 = -11323 // this is year -17999, calculation: 50 * daysIn400Year + final val YearZero = -17999 final val toYearZero = to2001 + 7304850 @transient lazy val defaultTimeZone = TimeZone.getDefault @@ -575,6 +576,144 @@ object DateTimeUtils { } /** + * The number of days for each month (not leap year) + */ + private val monthDays = Array(31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31) + + /** + * Returns the date value for the first day of the given month. + * The month is expressed in months since year zero (17999 BC), starting from 0. 
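+   * For example, January 1970 is absolute month (1970 + 17999) * 12 = 239628, for
+   * which this should return 0, i.e. 1970-01-01 expressed in days since the epoch.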
+   */
+  private def firstDayOfMonth(absoluteMonth: Int): Int = {
+    val absoluteYear = absoluteMonth / 12
+    var monthInYear = absoluteMonth - absoluteYear * 12
+    var date = getDateFromYear(absoluteYear)
+    if (monthInYear >= 2 && isLeapYear(absoluteYear + YearZero)) {
+      date += 1
+    }
+    while (monthInYear > 0) {
+      date += monthDays(monthInYear - 1)
+      monthInYear -= 1
+    }
+    date
+  }
+
+  /**
+   * Returns the date value for January 1 of the given year.
+   * The year is expressed in years since year zero (17999 BC), starting from 0.
+   */
+  private def getDateFromYear(absoluteYear: Int): Int = {
+    val absoluteDays = (absoluteYear * 365 + absoluteYear / 400 - absoluteYear / 100
+      + absoluteYear / 4)
+    absoluteDays - toYearZero
+  }
+
+  /**
+   * Adds a year-month interval to a date.
+   * Returns a date value, expressed in days since 1.1.1970.
+   */
+  def dateAddMonths(days: Int, months: Int): Int = {
+    val absoluteMonth = (getYear(days) - YearZero) * 12 + getMonth(days) - 1 + months
+    val currentMonthInYear = absoluteMonth % 12
+    val currentYear = absoluteMonth / 12
+    val leapDay = if (currentMonthInYear == 1 && isLeapYear(currentYear + YearZero)) 1 else 0
+    val lastDayOfMonth = monthDays(currentMonthInYear) + leapDay
+
+    val dayOfMonth = getDayOfMonth(days)
+    val currentDayInMonth = if (getDayOfMonth(days + 1) == 1 || dayOfMonth >= lastDayOfMonth) {
+      // last day of the month
+      lastDayOfMonth
+    } else {
+      dayOfMonth
+    }
+    firstDayOfMonth(absoluteMonth) + currentDayInMonth - 1
+  }
+
+  /**
+   * Adds a full interval to a timestamp.
+   * Returns a timestamp value, expressed in microseconds since 1.1.1970 00:00:00.
+   */
+  def timestampAddInterval(start: Long, months: Int, microseconds: Long): Long = {
+    val days = millisToDays(start / 1000L)
+    val newDays = dateAddMonths(days, months)
+    daysToMillis(newDays) * 1000L + start - daysToMillis(days) * 1000L + microseconds
+  }
+
+  /**
+   * Returns the last dayInMonth in the month it belongs to. The date is expressed
+   * in days since 1.1.1970. The return value starts from 1.
+   */
+  private def getLastDayInMonthOfMonth(date: Int): Int = {
+    var (year, dayInYear) = getYearAndDayInYear(date)
+    if (isLeapYear(year)) {
+      if (dayInYear > 31 && dayInYear <= 60) {
+        return 29
+      } else if (dayInYear > 60) {
+        dayInYear = dayInYear - 1
+      }
+    }
+    if (dayInYear <= 31) {
+      31
+    } else if (dayInYear <= 59) {
+      28
+    } else if (dayInYear <= 90) {
+      31
+    } else if (dayInYear <= 120) {
+      30
+    } else if (dayInYear <= 151) {
+      31
+    } else if (dayInYear <= 181) {
+      30
+    } else if (dayInYear <= 212) {
+      31
+    } else if (dayInYear <= 243) {
+      31
+    } else if (dayInYear <= 273) {
+      30
+    } else if (dayInYear <= 304) {
+      31
+    } else if (dayInYear <= 334) {
+      30
+    } else {
+      31
+    }
+  }
+
+  /**
+   * Returns the number of months between time1 and time2. time1 and time2 are expressed in
+   * microseconds since 1.1.1970.
+   *
+   * If time1 and time2 have the same day of month, or both are the last day of the month,
+   * it returns an integer (the time-of-day portion is ignored).
+   *
+   * Otherwise, the difference is calculated based on 31 days per month, and rounded to
+   * 8 digits.
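+   * For example, monthsBetween of 2015-03-31 22:00:00 and 2015-02-28 00:00:00 is exactly
+   * 1.0 (both are the last day of their month), while monthsBetween of
+   * 1997-02-28 10:30:00 and 1996-10-30 00:00:00 is 3.94959677 (see the tests below).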
+   */
+  def monthsBetween(time1: Long, time2: Long): Double = {
+    val millis1 = time1 / 1000L
+    val millis2 = time2 / 1000L
+    val date1 = millisToDays(millis1)
+    val date2 = millisToDays(millis2)
+    // TODO(davies): get year, month, dayOfMonth from single function
+    val dayInMonth1 = getDayOfMonth(date1)
+    val dayInMonth2 = getDayOfMonth(date2)
+    val months1 = getYear(date1) * 12 + getMonth(date1)
+    val months2 = getYear(date2) * 12 + getMonth(date2)
+
+    if (dayInMonth1 == dayInMonth2 || (dayInMonth1 == getLastDayInMonthOfMonth(date1)
+      && dayInMonth2 == getLastDayInMonthOfMonth(date2))) {
+      return (months1 - months2).toDouble
+    }
+    // milliseconds are enough for 8-digit precision on the right side
+    val timeInDay1 = millis1 - daysToMillis(date1)
+    val timeInDay2 = millis2 - daysToMillis(date2)
+    val timesBetween = (timeInDay1 - timeInDay2).toDouble / MILLIS_PER_DAY
+    val diff = (months1 - months2).toDouble + (dayInMonth1 - dayInMonth2 + timesBetween) / 31.0
+    // rounding to 8 digits
+    math.round(diff * 1e8) / 1e8
+  }
+
  /**
   * Returns day of week from String. Starting from Thursday, marked as 0.
   * (Because 1970-01-01 is Thursday).
   */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 1d9ee5ddf3a5a..70608771dd110 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -17,12 +17,15 @@

 package org.apache.spark.sql.catalyst.analysis

+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval

 class HiveTypeCoercionSuite extends PlanTest {

@@ -400,6 +403,33 @@ class HiveTypeCoercionSuite extends PlanTest {
     }
   }

+  test("rule for date/timestamp operations") {
+    val dateTimeOperations = HiveTypeCoercion.DateTimeOperations
+    val date = Literal(new java.sql.Date(0L))
+    val timestamp = Literal(new Timestamp(0L))
+    val interval = Literal(new CalendarInterval(0, 0))
+    val str = Literal("2015-01-01")
+
+    ruleTest(dateTimeOperations, Add(date, interval), Cast(TimeAdd(date, interval), DateType))
+    ruleTest(dateTimeOperations, Add(interval, date), Cast(TimeAdd(date, interval), DateType))
+    ruleTest(dateTimeOperations, Add(timestamp, interval),
+      Cast(TimeAdd(timestamp, interval), TimestampType))
+    ruleTest(dateTimeOperations, Add(interval, timestamp),
+      Cast(TimeAdd(timestamp, interval), TimestampType))
+    ruleTest(dateTimeOperations, Add(str, interval), Cast(TimeAdd(str, interval), StringType))
+    ruleTest(dateTimeOperations, Add(interval, str), Cast(TimeAdd(str, interval), StringType))
+
+    ruleTest(dateTimeOperations, Subtract(date, interval), Cast(TimeSub(date, interval), DateType))
+    ruleTest(dateTimeOperations, Subtract(timestamp, interval),
+      Cast(TimeSub(timestamp, interval), TimestampType))
+    ruleTest(dateTimeOperations, Subtract(str, interval), Cast(TimeSub(str, interval), StringType))
+
+    // interval operations should not be affected
+    ruleTest(dateTimeOperations, Add(interval, interval), Add(interval, interval))
+    ruleTest(dateTimeOperations, Subtract(interval, interval), Subtract(interval, interval))
+  }
+
   /**
    * There are rules
that need to not fire before child expressions get resolved. * We use this test to make sure those rules do not fire early. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index e1387f945ffa4..fd1d6c1d25497 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -22,8 +22,8 @@ import java.text.SimpleDateFormat import java.util.Calendar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.sql.types._ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -48,56 +48,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("DayOfYear") { val sdfDay = new SimpleDateFormat("D") - (2002 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - (1998 to 2002).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (1969 to 1970).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2402 to 2404).foreach { y => - (0 to 11).foreach { m => - (0 to 5).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.DATE, i) - checkEvaluation(DayOfYear(Literal(new Date(c.getTimeInMillis))), - sdfDay.format(c.getTime).toInt) - } - } - } - - (2398 to 2402).foreach { y => - (0 to 11).foreach { m => + (0 to 3).foreach { m => (0 to 5).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) @@ -117,7 +69,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Year(Cast(Literal(ts), DateType)), 2013) val c = Calendar.getInstance() - (2000 to 2010).foreach { y => + (2000 to 2002).foreach { y => (0 to 11 by 11).foreach { m => c.set(y, m, 28) (0 to 5 * 24).foreach { i => @@ -155,20 +107,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Month(Cast(Literal(ts), DateType)), 11) (2003 to 2004).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => - val c = Calendar.getInstance() - c.set(y, m, 28, 0, 0, 0) - c.add(Calendar.HOUR_OF_DAY, i) - checkEvaluation(Month(Literal(new Date(c.getTimeInMillis))), - c.get(Calendar.MONTH) + 1) - } - } - } - - (1999 to 2000).foreach { y => - (0 to 11).foreach { m => - (0 to 5 * 24).foreach { i => + (0 to 3).foreach { m => + (0 to 2 * 24).foreach { i => val c = Calendar.getInstance() c.set(y, m, 28, 0, 0, 0) c.add(Calendar.HOUR_OF_DAY, i) @@ -262,6 +202,112 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("date_add") 
{ + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + DateAdd(Literal(Date.valueOf("2016-02-28")), Literal(-365)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateAdd(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateAdd(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("date_sub") { + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2014-12-31"))) + checkEvaluation( + DateSub(Literal(Date.valueOf("2015-01-01")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-01-02"))) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(DateSub(Literal(Date.valueOf("2016-02-28")), Literal.create(null, IntegerType)), + null) + checkEvaluation(DateSub(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("time_add") { + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal(new CalendarInterval(1, 123000L))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00.123"))) + + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeAdd(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeAdd(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("time_sub") { + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-03-31 10:00:00")), + Literal(new CalendarInterval(1, 0))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-29 10:00:00"))) + checkEvaluation( + TimeSub( + Literal(Timestamp.valueOf("2016-03-30 00:00:01")), + Literal(new CalendarInterval(1, 2000000.toLong))), + DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-02-28 23:59:59"))) + + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal(new CalendarInterval(1, 123000L))), + null) + checkEvaluation( + TimeSub(Literal(Timestamp.valueOf("2016-01-29 10:00:00")), + Literal.create(null, CalendarIntervalType)), + null) + checkEvaluation( + TimeSub(Literal.create(null, TimestampType), Literal.create(null, CalendarIntervalType)), + null) + } + + test("add_months") { + checkEvaluation(AddMonths(Literal(Date.valueOf("2015-01-30")), Literal(1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2015-02-28"))) + checkEvaluation(AddMonths(Literal(Date.valueOf("2016-03-30")), Literal(-1)), + DateTimeUtils.fromJavaDate(Date.valueOf("2016-02-29"))) + checkEvaluation( + AddMonths(Literal(Date.valueOf("2015-01-30")), Literal.create(null, IntegerType)), + null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal(1)), null) + checkEvaluation(AddMonths(Literal.create(null, DateType), Literal.create(null, IntegerType)), + null) + } + + test("months_between") { + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("1997-02-28 10:30:00")), + Literal(Timestamp.valueOf("1996-10-30 00:00:00"))), + 3.94959677) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-01-30 11:52:00")), + Literal(Timestamp.valueOf("2015-01-30 11:50:00"))), + 0.0) + checkEvaluation( + 
MonthsBetween(Literal(Timestamp.valueOf("2015-01-31 00:00:00")), + Literal(Timestamp.valueOf("2015-03-31 22:00:00"))), + -2.0) + checkEvaluation( + MonthsBetween(Literal(Timestamp.valueOf("2015-03-31 22:00:00")), + Literal(Timestamp.valueOf("2015-02-28 00:00:00"))), + 1.0) + val t = Literal(Timestamp.valueOf("2015-03-31 22:00:00")) + val tnull = Literal.create(null, TimestampType) + checkEvaluation(MonthsBetween(t, tnull), null) + checkEvaluation(MonthsBetween(tnull, t), null) + checkEvaluation(MonthsBetween(tnull, tnull), null) + } + test("last_day") { checkEvaluation(LastDay(Literal(Date.valueOf("2015-02-28"))), Date.valueOf("2015-02-28")) checkEvaluation(LastDay(Literal(Date.valueOf("2015-03-27"))), Date.valueOf("2015-03-31")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index fab9eb9cd4c9f..60d2bcfe13757 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -19,47 +19,48 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ class DateTimeUtilsSuite extends SparkFunSuite { private[this] def getInUTCDays(timestamp: Long): Int = { val tz = TimeZone.getDefault - ((timestamp + tz.getOffset(timestamp)) / DateTimeUtils.MILLIS_PER_DAY).toInt + ((timestamp + tz.getOffset(timestamp)) / MILLIS_PER_DAY).toInt } test("timestamp and us") { val now = new Timestamp(System.currentTimeMillis()) now.setNanos(1000) - val ns = DateTimeUtils.fromJavaTimestamp(now) + val ns = fromJavaTimestamp(now) assert(ns % 1000000L === 1) - assert(DateTimeUtils.toJavaTimestamp(ns) === now) + assert(toJavaTimestamp(ns) === now) List(-111111111111L, -1L, 0, 1L, 111111111111L).foreach { t => - val ts = DateTimeUtils.toJavaTimestamp(t) - assert(DateTimeUtils.fromJavaTimestamp(ts) === t) - assert(DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJavaTimestamp(ts)) === ts) + val ts = toJavaTimestamp(t) + assert(fromJavaTimestamp(ts) === t) + assert(toJavaTimestamp(fromJavaTimestamp(ts)) === ts) } } test("us and julian day") { - val (d, ns) = DateTimeUtils.toJulianDay(0) - assert(d === DateTimeUtils.JULIAN_DAY_OF_EPOCH) - assert(ns === DateTimeUtils.SECONDS_PER_DAY / 2 * DateTimeUtils.NANOS_PER_SECOND) - assert(DateTimeUtils.fromJulianDay(d, ns) == 0L) + val (d, ns) = toJulianDay(0) + assert(d === JULIAN_DAY_OF_EPOCH) + assert(ns === SECONDS_PER_DAY / 2 * NANOS_PER_SECOND) + assert(fromJulianDay(d, ns) == 0L) val t = new Timestamp(61394778610000L) // (2015, 6, 11, 10, 10, 10, 100) - val (d1, ns1) = DateTimeUtils.toJulianDay(DateTimeUtils.fromJavaTimestamp(t)) - val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) + val (d1, ns1) = toJulianDay(fromJavaTimestamp(t)) + val t2 = toJavaTimestamp(fromJulianDay(d1, ns1)) assert(t.equals(t2)) } test("SPARK-6785: java date conversion before and after epoch") { def checkFromToJavaDate(d1: Date): Unit = { - val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + val d2 = toJavaDate(fromJavaDate(d1)) assert(d2.toString === d1.toString) } @@ -95,157 +96,156 @@ class DateTimeUtilsSuite extends 
SparkFunSuite { } test("string to date") { - import DateTimeUtils.millisToDays var c = Calendar.getInstance() c.set(2015, 0, 28, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-01-28")).get === + assert(stringToDate(UTF8String.fromString("2015-01-28")).get === millisToDays(c.getTimeInMillis)) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015")).get === + assert(stringToDate(UTF8String.fromString("2015")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03")).get === + assert(stringToDate(UTF8String.fromString("2015-03")).get === millisToDays(c.getTimeInMillis)) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 ")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18 123142")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18 123142")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T123123")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T123123")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18T")).get === + assert(stringToDate(UTF8String.fromString("2015-03-18T")).get === millisToDays(c.getTimeInMillis)) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("20150318")).isEmpty) - assert(DateTimeUtils.stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToDate(UTF8String.fromString("20150318")).isEmpty) + assert(stringToDate(UTF8String.fromString("2015-031-8")).isEmpty) } test("string to timestamp") { var c = Calendar.getInstance() c.set(1969, 11, 31, 16, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === + assert(stringToTimestamp(UTF8String.fromString("1969-12-31 16:00:00")).get === c.getTimeInMillis * 1000) c.set(2015, 0, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015")).get === + assert(stringToTimestamp(UTF8String.fromString("2015")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 1, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) c.set(Calendar.MILLISECOND, 0) - 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 ")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-13:53")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-13:53")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17Z")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 12:03:17Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T12:03:17-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17+07:03")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance() c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.123")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("UTC")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 456) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.456Z")).get === c.getTimeInMillis * 
1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18 12:03:17.456Z")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-1:0")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123-01:00")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123+07:30")).get === c.getTimeInMillis * 1000) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.123121+7:30")).get === c.getTimeInMillis * 1000 + 121) c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30")) c.set(2015, 2, 18, 12, 3, 17) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03:17.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -254,7 +254,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 0) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15")).get === c.getTimeInMillis * 1000) @@ -263,7 +263,7 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("T18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) @@ -272,93 +272,130 @@ class DateTimeUtilsSuite extends SparkFunSuite { c.set(Calendar.MINUTE, 12) c.set(Calendar.SECOND, 15) c.set(Calendar.MILLISECOND, 123) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("18:12:15.12312+7:30")).get === c.getTimeInMillis * 1000 + 120) c = Calendar.getInstance() c.set(2011, 4, 6, 7, 8, 9) c.set(Calendar.MILLISECOND, 100) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2011-05-06 07:08:09.1000")).get === c.getTimeInMillis * 1000) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("238")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) - assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) - 
assert(DateTimeUtils.stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp(UTF8String.fromString("238")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18 123142")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18T123123")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-03-18X")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015/03/18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015.03.18")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("20150318")).isEmpty) + assert(stringToTimestamp(UTF8String.fromString("2015-031-8")).isEmpty) + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-20:0")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-0:70")).isEmpty) - assert(DateTimeUtils.stringToTimestamp( + assert(stringToTimestamp( UTF8String.fromString("2015-03-18T12:03.17-1:0:0")).isEmpty) } test("hours") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 13) + assert(getHours(c.getTimeInMillis * 1000) === 13) c.set(2015, 12, 8, 2, 7, 9) - assert(DateTimeUtils.getHours(c.getTimeInMillis * 1000) === 2) + assert(getHours(c.getTimeInMillis * 1000) === 2) } test("minutes") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 2) + assert(getMinutes(c.getTimeInMillis * 1000) === 2) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getMinutes(c.getTimeInMillis * 1000) === 7) + assert(getMinutes(c.getTimeInMillis * 1000) === 7) } test("seconds") { val c = Calendar.getInstance() c.set(2015, 2, 18, 13, 2, 11) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 11) + assert(getSeconds(c.getTimeInMillis * 1000) === 11) c.set(2015, 2, 8, 2, 7, 9) - assert(DateTimeUtils.getSeconds(c.getTimeInMillis * 1000) === 9) + assert(getSeconds(c.getTimeInMillis * 1000) === 9) } test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 77) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) + assert(getDayInYear(getInUTCDays(c.getTimeInMillis)) === 78) } test("get year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2015) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2015) c.set(2012, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getYear(getInUTCDays(c.getTimeInMillis)) === 2012) + assert(getYear(getInUTCDays(c.getTimeInMillis)) === 2012) } test("get quarter") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 1) c.set(2012, 11, 18, 0, 0, 0) - assert(DateTimeUtils.getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) + assert(getQuarter(getInUTCDays(c.getTimeInMillis)) === 4) } test("get month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 3) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 3) c.set(2012, 11, 18, 0, 0, 0) - 
assert(DateTimeUtils.getMonth(getInUTCDays(c.getTimeInMillis)) === 12) + assert(getMonth(getInUTCDays(c.getTimeInMillis)) === 12) } test("get day of month") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 18) c.set(2012, 11, 24, 0, 0, 0) - assert(DateTimeUtils.getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + assert(getDayOfMonth(getInUTCDays(c.getTimeInMillis)) === 24) + } + + test("date add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val days1 = millisToDays(c1.getTimeInMillis) + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29) + assert(dateAddMonths(days1, 36) === millisToDays(c2.getTimeInMillis)) + c2.set(1996, 0, 31) + assert(dateAddMonths(days1, -13) === millisToDays(c2.getTimeInMillis)) + } + + test("timestamp add months") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + c1.set(Calendar.MILLISECOND, 0) + val ts1 = c1.getTimeInMillis * 1000L + val c2 = Calendar.getInstance() + c2.set(2000, 1, 29, 10, 30, 0) + c2.set(Calendar.MILLISECOND, 123) + val ts2 = c2.getTimeInMillis * 1000L + assert(timestampAddInterval(ts1, 36, 123000) === ts2) + } + + test("monthsBetween") { + val c1 = Calendar.getInstance() + c1.set(1997, 1, 28, 10, 30, 0) + val c2 = Calendar.getInstance() + c2.set(1996, 9, 30, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 3.94959677) + c2.set(2000, 1, 28, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(2000, 1, 29, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === -36) + c2.set(1996, 2, 31, 0, 0, 0) + assert(monthsBetween(c1.getTimeInMillis * 1000L, c2.getTimeInMillis * 1000L) === 11) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3f440e062eb96..168894d66117d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1927,6 +1927,14 @@ object functions { // DateTime functions ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Returns the date that is numMonths after startDate. + * @group datetime_funcs + * @since 1.5.0 + */ + def add_months(startDate: Column, numMonths: Int): Column = + AddMonths(startDate.expr, Literal(numMonths)) + /** * Converts a date/timestamp/string to a value of string in the format specified by the date * format given by the second argument. @@ -1959,6 +1967,20 @@ object functions { def date_format(dateColumnName: String, format: String): Column = date_format(Column(dateColumnName), format) + /** + * Returns the date that is `days` days after `start`. + * @group datetime_funcs + * @since 1.5.0 + */ + def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + + /** + * Returns the date that is `days` days before `start`. + * @group datetime_funcs + * @since 1.5.0 + */ + def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs @@ -2067,6 +2089,13 @@ object functions { */ def minute(columnName: String): Column = minute(Column(columnName)) + + /** + * Returns the number of months between dates `date1` and `date2`.
+ * @group datetime_funcs + * @since 1.5.0 + */ + def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + /** * Given a date column, returns the first date which is later than the value of the date column * that is on the specified day of the week. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index df4cb57ac5b21..b7267c413165a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -22,6 +22,7 @@ import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ +import org.apache.spark.unsafe.types.CalendarInterval class DateFunctionsSuite extends QueryTest { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext @@ -206,6 +207,122 @@ class DateFunctionsSuite extends QueryTest { Row(15, 15, 15)) } + test("function date_add") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_add(col("d"), 1)), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + checkAnswer( + df.select(date_add(col("t"), 3)), + Seq(Row(Date.valueOf("2015-06-04")), Row(Date.valueOf("2015-06-05")))) + checkAnswer( + df.select(date_add(col("s"), 5)), + Seq(Row(Date.valueOf("2015-06-06")), Row(Date.valueOf("2015-06-07")))) + checkAnswer( + df.select(date_add(col("ss"), 7)), + Seq(Row(Date.valueOf("2015-06-08")), Row(Date.valueOf("2015-06-09")))) + + checkAnswer(df.selectExpr("DATE_ADD(null, 1)"), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_ADD(d, 1)"""), + Seq(Row(Date.valueOf("2015-06-02")), Row(Date.valueOf("2015-06-03")))) + } + + test("function date_sub") { + val st1 = "2015-06-01 12:34:56" + val st2 = "2015-06-02 12:34:56" + val t1 = Timestamp.valueOf(st1) + val t2 = Timestamp.valueOf(st2) + val s1 = "2015-06-01" + val s2 = "2015-06-02" + val d1 = Date.valueOf(s1) + val d2 = Date.valueOf(s2) + val df = Seq((t1, d1, s1, st1), (t2, d2, s2, st2)).toDF("t", "d", "s", "ss") + checkAnswer( + df.select(date_sub(col("d"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("t"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("s"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(col("ss"), 1)), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + checkAnswer( + df.select(date_sub(lit(null), 1)).limit(1), Row(null)) + + checkAnswer(df.selectExpr("""DATE_SUB(d, null)"""), Seq(Row(null), Row(null))) + checkAnswer( + df.selectExpr("""DATE_SUB(d, 1)"""), + Seq(Row(Date.valueOf("2015-05-31")), Row(Date.valueOf("2015-06-01")))) + } + + test("time_add") { + val t1 = Timestamp.valueOf("2015-07-31 23:59:59") + val t2 = Timestamp.valueOf("2015-12-31 00:00:00") + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-12-31") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + 
df.selectExpr(s"d + $i"), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2016-02-29")))) + checkAnswer( + df.selectExpr(s"t + $i"), + Seq(Row(Timestamp.valueOf("2015-10-01 00:00:01")), + Row(Timestamp.valueOf("2016-02-29 00:00:02")))) + } + + test("time_sub") { + val t1 = Timestamp.valueOf("2015-10-01 00:00:01") + val t2 = Timestamp.valueOf("2016-02-29 00:00:02") + val d1 = Date.valueOf("2015-09-30") + val d2 = Date.valueOf("2016-02-29") + val i = new CalendarInterval(2, 2000000L) + val df = Seq((1, t1, d1), (3, t2, d2)).toDF("n", "t", "d") + checkAnswer( + df.selectExpr(s"d - $i"), + Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-12-30")))) + checkAnswer( + df.selectExpr(s"t - $i"), + Seq(Row(Timestamp.valueOf("2015-07-31 23:59:59")), + Row(Timestamp.valueOf("2015-12-31 00:00:00")))) + } + + test("function add_months") { + val d1 = Date.valueOf("2015-08-31") + val d2 = Date.valueOf("2015-02-28") + val df = Seq((1, d1), (2, d2)).toDF("n", "d") + checkAnswer( + df.select(add_months(col("d"), 1)), + Seq(Row(Date.valueOf("2015-09-30")), Row(Date.valueOf("2015-03-31")))) + checkAnswer( + df.selectExpr("add_months(d, -1)"), + Seq(Row(Date.valueOf("2015-07-31")), Row(Date.valueOf("2015-01-31")))) + } + + test("function months_between") { + val d1 = Date.valueOf("2015-07-31") + val d2 = Date.valueOf("2015-02-16") + val t1 = Timestamp.valueOf("2014-09-30 23:30:00") + val t2 = Timestamp.valueOf("2015-09-16 12:00:00") + val s1 = "2014-09-15 11:30:00" + val s2 = "2015-10-01 00:00:00" + val df = Seq((t1, d1, s1), (t2, d2, s2)).toDF("t", "d", "s") + checkAnswer(df.select(months_between(col("t"), col("d"))), Seq(Row(-10.0), Row(7.0))) + checkAnswer(df.selectExpr("months_between(t, s)"), Seq(Row(0.5), Row(-0.5))) + } + test("function last_day") { val df1 = Seq((1, "2015-07-23"), (2, "2015-07-24")).toDF("i", "d") val df2 = Seq((1, "2015-07-23 00:11:22"), (2, "2015-07-24 11:22:33")).toDF("i", "t") From 89cda69ecd5ef942a68ad13fc4e1f4184010f087 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Thu, 30 Jul 2015 14:08:59 -0700 Subject: [PATCH 176/219] [SPARK-9454] Change LDASuite tests to use vector comparisons jkbradley: this replaces the current hacky string comparison with proper vector comparisons. Author: Feynman Liang Closes #7775 from feynmanliang/SPARK-9454-ldasuite-vector-compare and squashes the following commits: bd91a82 [Feynman Liang] Remove println 905c76e [Feynman Liang] Fix string compare in distributed EM 2f24c13 [Feynman Liang] Improve LDASuite tests --- .../spark/mllib/clustering/LDASuite.scala | 33 ++++++++----------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d74482d3a7598..c43e1e575c09c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -83,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.topicsMatrix === localModel.topicsMatrix) // Check: topic summaries - // The odd decimal formatting and sorting is a hack to do a robust comparison. 
- val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => - assert(t1 === t2) + val topicSummary = model.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) => + assert(topics ~== topicsLocal absTol 0.01) } // Check: per-doc topic distributions @@ -197,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // verify the result, Note this generate the identical result as // [[https://github.com/Blei-Lab/onlineldavb]] - val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1) - assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2) + val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t) + val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t) + val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950) + val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050) + assert(topic1 ~== expectedTopic1 absTol 0.01) + assert(topic2 ~== expectedTopic2 absTol 0.01) } test("OnlineLDAOptimizer with toy data") { From 0dbd6963d589a8f6ad344273f3da7df680ada515 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 30 Jul 2015 15:39:46 -0700 Subject: [PATCH 177/219] [SPARK-9479] [STREAMING] [TESTS] Fix ReceiverTrackerSuite failure for maven build and other potential test failures in Streaming See https://issues.apache.org/jira/browse/SPARK-9479 for the failure cause. The PR includes the following changes: 1. Make ReceiverTrackerSuite create StreamingContext in the test body. 2. Fix places that don't stop StreamingContext. I verified no SparkContext was stopped in the shutdown hook locally after this fix. 3. Fix an issue that `ReceiverTracker.endpoint` may be null. 4. Make sure stopping SparkContext in non-main thread won't fail other tests. 
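Most of the changes for (1) and (2) lean on a loan-style helper so that a StreamingContext is always stopped even when a test body throws. A minimal sketch of the idea (illustrative only; the actual withStreamingContext helper in TestSuiteBase may differ in details):

    import org.apache.spark.streaming.StreamingContext

    // Run `body` against the given context, then stop it unconditionally so
    // that no SparkContext can leak into whatever test runs next.
    def withStreamingContext[R](ssc: StreamingContext)(body: StreamingContext => R): R = {
      try {
        body(ssc)
      } finally {
        ssc.stop(stopSparkContext = true)
      }
    }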
Author: zsxwing Closes #7797 from zsxwing/fix-ReceiverTrackerSuite and squashes the following commits: 3a4bb98 [zsxwing] Fix another potential NPE d7497df [zsxwing] Fix ReceiverTrackerSuite; make sure StreamingContext in tests is closed --- .../StreamingLogisticRegressionSuite.scala | 21 +++++-- .../clustering/StreamingKMeansSuite.scala | 17 ++++-- .../StreamingLinearRegressionSuite.scala | 21 +++++-- .../streaming/scheduler/ReceiverTracker.scala | 12 +++- .../apache/spark/streaming/JavaAPISuite.java | 1 + .../streaming/BasicOperationsSuite.scala | 58 ++++++++++--------- .../spark/streaming/InputStreamsSuite.scala | 38 ++++++------ .../spark/streaming/MasterFailureTest.scala | 8 ++- .../streaming/StreamingContextSuite.scala | 22 +++++-- .../streaming/StreamingListenerSuite.scala | 13 ++++- .../scheduler/ReceiverTrackerSuite.scala | 56 +++++++++--------- .../StreamingJobProgressListenerSuite.scala | 19 ++++-- 12 files changed, 183 insertions(+), 103 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index fd653296c9d97..d7b291d5a6330 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -24,13 +24,22 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Test if we can accurately learn B for Y = logistic(BX) on streaming data test("parameter accuracy") { @@ -50,7 +59,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -84,7 +93,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - B))) inputDStream.count() @@ -118,7 +127,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -147,7 +156,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase } // train and predict - val ssc = 
setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -167,7 +176,7 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase .setNumIterations(10) val numBatches = 10 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index ac01622b8a089..3645d29dccdb2 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.random.XORShiftRandom @@ -28,6 +28,15 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { override def maxWaitTimeMillis: Int = 30000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + test("accuracy for single center and equivalence to grand average") { // set parameters val numBatches = 10 @@ -46,7 +55,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -82,7 +91,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) @@ -114,7 +123,7 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) // setup and run the model training - val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { kMeans.trainOn(inputDStream) inputDStream.count() }) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index a2a4c5f6b8b70..34c07ed170816 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -22,14 +22,23 @@ import scala.collection.mutable.ArrayBuffer import 
org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.streaming.TestSuiteBase class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // use longer wait time to ensure job completion override def maxWaitTimeMillis: Int = 20000 + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // Assert that two values are equal within tolerance epsilon def assertEqual(v1: Double, v2: Double, epsilon: Double) { def errorMessage = v1.toString + " did not equal " + v2.toString @@ -62,7 +71,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model training to input stream - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.count() }) @@ -98,7 +107,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { // apply model training to input stream, storing the intermediate results // (we add a count to ensure the result is a DStream) - val ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) inputDStream.foreachRDD(x => history.append(math.abs(model.latestModel().weights(0) - 10.0))) inputDStream.count() @@ -129,7 +138,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // apply model predictions to test stream - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) // collect the output as (true, estimated) tuples @@ -156,7 +165,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { } // train and predict - val ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { + ssc = setupStreams(testInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) }) @@ -177,7 +186,7 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val numBatches = 10 val nPoints = 100 val emptyInput = Seq.empty[Seq[LabeledPoint]] - val ssc = setupStreams(emptyInput, + ssc = setupStreams(emptyInput, (inputDStream: DStream[LabeledPoint]) => { model.trainOn(inputDStream) model.predictOnValues(inputDStream.map(x => (x.label, x.features))) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 6270137951b5a..e076fb5ea174b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -223,7 +223,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Signal the receivers to delete old block data if (WriteAheadLogUtils.enableReceiverLog(ssc.conf)) { logInfo(s"Cleanup old received batch data: $cleanupThreshTime") - 
endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + synchronized { + if (isTrackerStarted) { + endpoint.send(CleanupOldBlocks(cleanupThreshTime)) + } + } } } @@ -285,8 +289,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } /** Update a receiver's maximum ingestion rate */ - def sendRateUpdate(streamUID: Int, newRate: Long): Unit = { - endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + def sendRateUpdate(streamUID: Int, newRate: Long): Unit = synchronized { + if (isTrackerStarted) { + endpoint.send(UpdateReceiverRateLimit(streamUID, newRate)) + } } /** Add new blocks for the given stream */ diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index a34f23475804a..e0718f73aa13f 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -1735,6 +1735,7 @@ public Integer call(String s) throws Exception { @SuppressWarnings("unchecked") @Test public void testContextGetOrCreate() throws InterruptedException { + ssc.stop(); final SparkConf conf = new SparkConf() .setMaster("local[2]") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 08faeaa58f419..255376807c957 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -81,39 +81,41 @@ class BasicOperationsSuite extends TestSuiteBase { test("repartition (more partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(5) - val ssc = setupStreams(input, operation, 2) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 5) - assert(second.size === 5) - assert(third.size === 5) - - assert(first.flatten.toSet.equals((1 to 100).toSet) ) - assert(second.flatten.toSet.equals((101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 2)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("repartition (fewer partitions)") { val input = Seq(1 to 100, 101 to 200, 201 to 300) val operation = (r: DStream[Int]) => r.repartition(2) - val ssc = setupStreams(input, operation, 5) - val output = runStreamsWithPartitions(ssc, 3, 3) - assert(output.size === 3) - val first = output(0) - val second = output(1) - val third = output(2) - - assert(first.size === 2) - assert(second.size === 2) - assert(third.size === 2) - - assert(first.flatten.toSet.equals((1 to 100).toSet)) - assert(second.flatten.toSet.equals( (101 to 200).toSet)) - assert(third.flatten.toSet.equals((201 to 300).toSet)) + withStreamingContext(setupStreams(input, operation, 5)) { ssc => + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val 
second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet.equals((1 to 100).toSet)) + assert(second.flatten.toSet.equals((101 to 200).toSet)) + assert(third.flatten.toSet.equals((201 to 300).toSet)) + } } test("groupByKey") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index b74d67c63a788..ec2852d9a0206 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -325,27 +325,31 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { } test("test track the number of input stream") { - val ssc = new StreamingContext(conf, batchDuration) + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => - class TestInputDStream extends InputDStream[String](ssc) { - def start() { } - def stop() { } - def compute(validTime: Time): Option[RDD[String]] = None - } + class TestInputDStream extends InputDStream[String](ssc) { + def start() {} - class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { - def getReceiver: Receiver[String] = null - } + def stop() {} + + def compute(validTime: Time): Option[RDD[String]] = None + } + + class TestReceiverInputDStream extends ReceiverInputDStream[String](ssc) { + def getReceiver: Receiver[String] = null + } - // Register input streams - val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) - val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) + // Register input streams + val receiverInputStreams = Array(new TestReceiverInputDStream, new TestReceiverInputDStream) + val inputStreams = Array(new TestInputDStream, new TestInputDStream, new TestInputDStream) - assert(ssc.graph.getInputStreams().length == receiverInputStreams.length + inputStreams.length) - assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) - assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) - assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) - assert(receiverInputStreams.map(_.id) === Array(0, 1)) + assert(ssc.graph.getInputStreams().length == + receiverInputStreams.length + inputStreams.length) + assert(ssc.graph.getReceiverInputStreams().length == receiverInputStreams.length) + assert(ssc.graph.getReceiverInputStreams() === receiverInputStreams) + assert(ssc.graph.getInputStreams().map(_.id) === Array.tabulate(5)(i => i)) + assert(receiverInputStreams.map(_.id) === Array(0, 1)) + } } def testFileStream(newFilesOnly: Boolean) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 6e9d4431090a2..0e64b57e0ffd8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -244,7 +244,13 @@ object MasterFailureTest extends Logging { } catch { case e: Exception => logError("Error running streaming context", e) } - if (killingThread.isAlive) killingThread.interrupt() + if (killingThread.isAlive) { + killingThread.interrupt() + // SparkContext.stop will set SparkEnv.env to null. 
We need to make sure SparkContext is + // stopped before running the next test. Otherwise, it's possible that we set SparkEnv.env + // to null after the next test creates the new SparkContext and fail the test. + killingThread.join() + } ssc.stop() logInfo("Has been killed = " + killed) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 4bba9691f8aa5..84a5fbb3d95eb 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -120,7 +120,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.streaming.checkpoint.directory", checkpointDirectory) - val ssc = new StreamingContext(myConf, batchDuration) + ssc = new StreamingContext(myConf, batchDuration) assert(ssc.checkpointDir != null) } @@ -369,16 +369,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } assert(exception.isInstanceOf[TestFailedDueToTimeoutException], "Did not wait for stop") + var t: Thread = null // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() ssc.awaitTermination() } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. + t.join() } test("awaitTermination after stop") { @@ -430,16 +436,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.awaitTerminationOrTimeout(500) === false) } + var t: Thread = null // test whether awaitTerminationOrTimeout() return true if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown - new Thread() { + t = new Thread() { override def run() { Thread.sleep(500) ssc.stop() } - }.start() + } + t.start() assert(ssc.awaitTerminationOrTimeout(10000) === true) } + // SparkContext.stop will set SparkEnv.env to null. We need to make sure SparkContext is stopped + // before running the next test. Otherwise, it's possible that we set SparkEnv.env to null after + // the next test creates the new SparkContext and fail the test. 
+ t.join() } test("getOrCreate") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 4bc1dd4a30fc4..d840c349bbbc4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -36,13 +36,22 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + // To make sure that the processing start and end times in collected // information are different for successive batches override def batchDuration: Duration = Milliseconds(100) override def actuallyWait: Boolean = true test("batch info reporting") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val collector = new BatchInfoCollector ssc.addStreamingListener(collector) runStreams(ssc, input.size, input.size) @@ -107,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { } test("receiver info reporting") { - val ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) + ssc = new StreamingContext("local[2]", "test", Milliseconds(1000)) val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver) inputStream.foreachRDD(_.count) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index aff8b53f752fa..afad5f16dbc71 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -29,36 +29,40 @@ import org.apache.spark.storage.StorageLevel /** Testsuite for receiver scheduling */ class ReceiverTrackerSuite extends TestSuiteBase { val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") - val ssc = new StreamingContext(sparkConf, Milliseconds(100)) - ignore("Receiver tracker - propagates rate limit") { - object ReceiverStartedWaiter extends StreamingListener { - @volatile - var started = false + test("Receiver tracker - propagates rate limit") { + withStreamingContext(new StreamingContext(sparkConf, Milliseconds(100))) { ssc => + object ReceiverStartedWaiter extends StreamingListener { + @volatile + var started = false - override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { - started = true + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + started = true + } } - } - - ssc.addStreamingListener(ReceiverStartedWaiter) - ssc.scheduler.listenerBus.start(ssc.sc) - SingletonTestRateReceiver.reset() - - val newRateLimit = 100L - val inputDStream = new RateLimitInputDStream(ssc) - val tracker = new ReceiverTracker(ssc) - tracker.start() - // we wait until the Receiver has registered with the tracker, - // otherwise our rate update is lost - eventually(timeout(5 seconds)) { - assert(ReceiverStartedWaiter.started) - } - tracker.sendRateUpdate(inputDStream.id, newRateLimit) - // this is an async message, we need to wait a bit for it to be processed - eventually(timeout(3 seconds)) { - assert(inputDStream.getCurrentRateLimit.get === 
newRateLimit) + ssc.addStreamingListener(ReceiverStartedWaiter) + ssc.scheduler.listenerBus.start(ssc.sc) + SingletonTestRateReceiver.reset() + + val newRateLimit = 100L + val inputDStream = new RateLimitInputDStream(ssc) + val tracker = new ReceiverTracker(ssc) + tracker.start() + try { + // we wait until the Receiver has registered with the tracker, + // otherwise our rate update is lost + eventually(timeout(5 seconds)) { + assert(ReceiverStartedWaiter.started) + } + tracker.sendRateUpdate(inputDStream.id, newRateLimit) + // this is an async message, we need to wait a bit for it to be processed + eventually(timeout(3 seconds)) { + assert(inputDStream.getCurrentRateLimit.get === newRateLimit) + } + } finally { + tracker.stop(false) + } } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index 0891309f956d2..995f1197ccdfd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -22,15 +22,24 @@ import java.util.Properties import org.scalatest.Matchers import org.apache.spark.scheduler.SparkListenerJobStart +import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.{Duration, Time, Milliseconds, TestSuiteBase} class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq val operation = (d: DStream[Int]) => d.map(x => x) + var ssc: StreamingContext = _ + + override def afterFunction() { + super.afterFunction() + if (ssc != null) { + ssc.stop() + } + } + private def createJobStart( batchTime: Time, outputOpId: Int, jobId: Int): SparkListenerJobStart = { val properties = new Properties() @@ -46,7 +55,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { test("onBatchSubmitted, onBatchStarted, onBatchCompleted, " + "onReceiverStarted, onReceiverError, onReceiverStopped") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val streamIdToInputInfo = Map( @@ -141,7 +150,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("Remove the old completed batches when exceeding the limit") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -158,7 +167,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("out-of-order onJobStart and onBatchXXX") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) val listener = new StreamingJobProgressListener(ssc) @@ -209,7 +218,7 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { } test("detect memory leak") { - val ssc = setupStreams(input, operation) + ssc = setupStreams(input, operation) val listener = new StreamingJobProgressListener(ssc) val limit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 1000) From 7f7a319c4ce07f07a6bd68100cf0a4f1da66269e Mon Sep 17 00:00:00 2001 From: martinzapletal Date: Thu, 30 Jul 2015 
15:57:14 -0700 Subject: [PATCH 178/219] [SPARK-8671] [ML] Added isotonic regression to the pipeline API. Author: martinzapletal Closes #7517 from zapletal-martin/SPARK-8671-isotonic-regression-api and squashes the following commits: 8c435c1 [martinzapletal] Review https://github.com/apache/spark/pull/7517 feedback update. bebbb86 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8671-isotonic-regression-api b68efc0 [martinzapletal] Added tests for param validation. 07c12bd [martinzapletal] Comments and refactoring. 834fcf7 [martinzapletal] Merge remote-tracking branch 'upstream/master' into SPARK-8671-isotonic-regression-api b611fee [martinzapletal] SPARK-8671. Added first version of isotonic regression to pipeline API --- .../ml/regression/IsotonicRegression.scala | 144 +++++++++++++++++ .../regression/IsotonicRegressionSuite.scala | 148 ++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..4ece8cf8cf0b6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.param.{Param, ParamMap, BooleanParam} +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} +import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DoubleType, DataType} +import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.storage.StorageLevel + +/** + * Params for isotonic regression. + */ +private[regression] trait IsotonicRegressionParams extends PredictorParams { + + /** + * Param for weight column name. + * TODO: Move weightCol to sharedParams. + * + * @group param + */ + final val weightCol: Param[String] = + new Param[String](this, "weightCol", "weight column name") + + /** @group getParam */ + final def getWeightCol: String = $(weightCol) + + /** + * Param for isotonic parameter. + * Isotonic (increasing) or antitonic (decreasing) sequence. 
+ * @group param + */ + final val isotonic: BooleanParam = + new BooleanParam(this, "isotonic", "isotonic (increasing) or antitonic (decreasing) sequence") + + /** @group getParam */ + final def getIsotonicParam: Boolean = $(isotonic) +} + +/** + * :: Experimental :: + * Isotonic regression. + * + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Uses [[org.apache.spark.mllib.regression.IsotonicRegression]]. + */ +@Experimental +class IsotonicRegression(override val uid: String) + extends Regressor[Double, IsotonicRegression, IsotonicRegressionModel] + with IsotonicRegressionParams { + + def this() = this(Identifiable.randomUID("isoReg")) + + /** + * Set the isotonic parameter. + * Default is true. + * @group setParam + */ + def setIsotonicParam(value: Boolean): this.type = set(isotonic, value) + setDefault(isotonic -> true) + + /** + * Set weight column param. + * Default is weight. + * @group setParam + */ + def setWeightParam(value: String): this.type = set(weightCol, value) + setDefault(weightCol -> "weight") + + override private[ml] def featuresDataType: DataType = DoubleType + + override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) + + private[this] def extractWeightedLabeledPoints( + dataset: DataFrame): RDD[(Double, Double, Double)] = { + + dataset.select($(labelCol), $(featuresCol), $(weightCol)) + .map { case Row(label: Double, features: Double, weights: Double) => + (label, features, weights) + } + } + + override protected def train(dataset: DataFrame): IsotonicRegressionModel = { + SchemaUtils.checkColumnType(dataset.schema, $(weightCol), DoubleType) + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val instances = extractWeightedLabeledPoints(dataset) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + + val isotonicRegression = new MLlibIsotonicRegression().setIsotonic($(isotonic)) + val parentModel = isotonicRegression.run(instances) + + new IsotonicRegressionModel(uid, parentModel) + } +} + +/** + * :: Experimental :: + * Model fitted by IsotonicRegression. + * Predicts using a piecewise linear function. + * + * For detailed rules see [[org.apache.spark.mllib.regression.IsotonicRegressionModel.predict()]]. + * + * @param parentModel A [[org.apache.spark.mllib.regression.IsotonicRegressionModel]] + * model trained by [[org.apache.spark.mllib.regression.IsotonicRegression]]. 
+ */ +class IsotonicRegressionModel private[ml] ( + override val uid: String, + private[ml] val parentModel: MLlibIsotonicRegressionModel) + extends RegressionModel[Double, IsotonicRegressionModel] + with IsotonicRegressionParams { + + override def featuresDataType: DataType = DoubleType + + override protected def predict(features: Double): Double = { + parentModel.predict(features) + } + + override def copy(extra: ParamMap): IsotonicRegressionModel = { + copyValues(new IsotonicRegressionModel(uid, parentModel), extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..66e4b170bae80 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + private val schema = StructType( + Array( + StructField("label", DoubleType), + StructField("features", DoubleType), + StructField("weight", DoubleType))) + + private val predictionSchema = StructType(Array(StructField("features", DoubleType))) + + private def generateIsotonicInput(labels: Seq[Double]): DataFrame = { + val data = Seq.tabulate(labels.size)(i => Row(labels(i), i.toDouble, 1d)) + val parallelData = sc.parallelize(data) + + sqlContext.createDataFrame(parallelData, schema) + } + + private def generatePredictionInput(features: Seq[Double]): DataFrame = { + val data = Seq.tabulate(features.size)(i => Row(features(i))) + + val parallelData = sc.parallelize(data) + sqlContext.createDataFrame(parallelData, predictionSchema) + } + + test("isotonic regression predictions") { + val dataset = generateIsotonicInput(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)) + val trainer = new IsotonicRegression().setIsotonicParam(true) + + val model = trainer.fit(dataset) + + val predictions = model + .transform(dataset) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.parentModel.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.parentModel.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.parentModel.isotonic) + } + + test("antitonic regression predictions") { + val dataset = generateIsotonicInput(Seq(7, 5, 3, 5, 1)) + 
val trainer = new IsotonicRegression().setIsotonicParam(false) + + val model = trainer.fit(dataset) + val features = generatePredictionInput(Seq(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0)) + + val predictions = model + .transform(features) + .select("prediction").map { + case Row(pred) => pred + }.collect() + + assert(predictions === Array(7, 7, 6, 5.5, 5, 4, 1)) + } + + test("params validation") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression + ParamsSuite.checkParams(ir) + val model = ir.fit(dataset) + ParamsSuite.checkParams(model) + } + + test("default params") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + val ir = new IsotonicRegression() + assert(ir.getLabelCol === "label") + assert(ir.getFeaturesCol === "features") + assert(ir.getWeightCol === "weight") + assert(ir.getPredictionCol === "prediction") + assert(ir.getIsotonicParam === true) + + val model = ir.fit(dataset) + model.transform(dataset) + .select("label", "features", "prediction", "weight") + .collect() + + assert(model.getLabelCol === "label") + assert(model.getFeaturesCol === "features") + assert(model.getWeightCol === "weight") + assert(model.getPredictionCol === "prediction") + assert(model.getIsotonicParam === true) + assert(model.hasParent) + } + + test("set parameters") { + val isotonicRegression = new IsotonicRegression() + .setIsotonicParam(false) + .setWeightParam("w") + .setFeaturesCol("f") + .setLabelCol("l") + .setPredictionCol("p") + + assert(isotonicRegression.getIsotonicParam === false) + assert(isotonicRegression.getWeightCol === "w") + assert(isotonicRegression.getFeaturesCol === "f") + assert(isotonicRegression.getLabelCol === "l") + assert(isotonicRegression.getPredictionCol === "p") + } + + test("missing column") { + val dataset = generateIsotonicInput(Seq(1, 2, 3)) + + intercept[IllegalArgumentException] { + new IsotonicRegression().setWeightParam("w").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setFeaturesCol("f").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().setLabelCol("l").fit(dataset) + } + + intercept[IllegalArgumentException] { + new IsotonicRegression().fit(dataset).setFeaturesCol("f").transform(dataset) + } + } +} From be7be6d4c7d978c20e601d1f5f56ecb3479814cb Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 30 Jul 2015 16:04:23 -0700 Subject: [PATCH 179/219] [SPARK-6684] [MLLIB] [ML] Add checkpointing to GBTs Add checkpointing to GradientBoostedTrees, GBTClassifier, GBTRegressor CC: mengxr Author: Joseph K. Bradley Closes #7804 from jkbradley/gbt-checkpoint3 and squashes the following commits: 3fbd7ba [Joseph K. Bradley] tiny fix b3e160c [Joseph K. Bradley] unset checkpoint dir after test 9cc3a04 [Joseph K. 
Bradley] added checkpointing to GBTs --- .../spark/mllib/clustering/LDAOptimizer.scala | 1 + .../mllib/tree/GradientBoostedTrees.scala | 48 +++++------ .../tree/configuration/BoostingStrategy.scala | 3 +- .../classification/GBTClassifierSuite.scala | 20 +++++ .../ml/regression/GBTRegressorSuite.scala | 20 ++++- .../tree/GradientBoostedTreesSuite.scala | 79 +++++++++++-------- 6 files changed, 114 insertions(+), 57 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 9dbec41efeada..d6f8b29a43dfd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer { this.checkpointInterval = lda.getCheckpointInterval this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index a835f96d5d0e3..9ce6faa137c41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.Algo._ @@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging { false } + // Prepare periodic checkpointers + val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( + treeStrategy.getCheckpointInterval, input.sparkContext) + timer.stop("init") logDebug("##########") logDebug("Building tree 0") logDebug("##########") - var data = input // Initialize tree timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy).run(data) + val firstTreeModel = new DecisionTree(treeStrategy).run(input) val firstTreeWeight = 1.0 baseLearners(0) = firstTreeModel baseLearnerWeights(0) = firstTreeWeight var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) // Note: A model of type regression is used since we require raw prediction @@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging { var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. 
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) + if (validate) validatePredErrorCheckpointer.update(validatePredError) var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestM = 1 - // pseudo-residual for second iteration - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - var m = 1 - while (m < numIterations) { + var doneLearning = false + while (m < numIterations && !doneLearning) { + // Update data with pseudo-residuals + val data = predError.zip(input).map { case ((pred, _), point) => + LabeledPoint(-loss.gradient(pred, point.label), point.features) + } + timer.start(s"building tree $m") logDebug("###################################################") logDebug("Gradient boosting tree iteration " + m) logDebug("###################################################") val model = new DecisionTree(treeStrategy).run(data) timer.stop(s"building tree $m") - // Create partial model + // Update partial model baseLearners(m) = model // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Technically, the weight should be optimized for the particular loss. // However, the behavior should be reasonable, though not optimal. baseLearnerWeights(m) = learningRate - // Note: A model of type regression is used since we require raw prediction - val partialModel = new GradientBoostedTreesModel( - Regression, baseLearners.slice(0, m + 1), - baseLearnerWeights.slice(0, m + 1)) predError = GradientBoostedTreesModel.updatePredictionError( input, predError, baseLearnerWeights(m), baseLearners(m), loss) + predErrorCheckpointer.update(predError) logDebug("error of gbt = " + predError.values.mean()) if (validate) { @@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging { validatePredError = GradientBoostedTreesModel.updatePredictionError( validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) + validatePredErrorCheckpointer.update(validatePredError) val currentValidateError = validatePredError.values.mean() if (bestValidateError - currentValidateError < validationTol) { - return new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) + doneLearning = true } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 + bestValidateError = currentValidateError + bestM = m + 1 } } - // Update data with pseudo-residuals - data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } m += 1 } @@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging { logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + predErrorCheckpointer.deleteAllCheckpoints() + validatePredErrorCheckpointer.deleteAllCheckpoints() if (persistedInput) input.unpersist() if (validate) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index 2d6b01524ff3d..9fd30c9b56319 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * learning rate should be 
between in the interval (0, 1] * @param validationTol Useful when runWithValidation is used. If the error rate on the * validation input between two iterations is less than the validationTol - * then stop. Ignored when [[run]] is used. + * then stop. Ignored when + * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. */ @Experimental case class BoostingStrategy( diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 82c345491bb3c..a7bc77965fefd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2) + val gbt = new GBTClassifier() + .setMaxDepth(2) + .setLossType("logistic") + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 9682edcd9ba84..dbdce0c9dea54 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.DataFrame +import org.apache.spark.util.Utils /** @@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { assert(predictions.min() < -1) } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val df = sqlContext.createDataFrame(data) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(5) + .setStepSize(0.1) + .setCheckpointInterval(2) + val model = gbt.fit(df) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 2521b3342181a..6fc9e8df621df 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext val algos = Array(Regression, Regression, Classification) val losses = Array(SquaredError, AbsoluteError, LogLoss) - (algos zip losses) map { - case (algo, loss) => { - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. - // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + val numTrees = gbtValidate.numTrees + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) } } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. 
+ // Note that convergenceTol is set to 0.0 + val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } } } + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + sc.setCheckpointDir(path) + + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, checkpointInterval = 2) + val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + sc.checkpointDir = None + Utils.deleteRecursively(tempDir) + } + } private object GradientBoostedTreesSuite { From e7905a9395c1a002f50bab29e16a729e14d4ed6f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 30 Jul 2015 16:15:43 -0700 Subject: [PATCH 180/219] [SPARK-9463] [ML] Expose model coefficients with names in SparkR RFormula Preview: ``` > summary(m) features coefficients 1 (Intercept) 1.6765001 2 Sepal_Length 0.3498801 3 Species.versicolor -0.9833885 4 Species.virginica -1.0075104 ``` Design doc from umbrella task: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit cc mengxr Author: Eric Liang Closes #7771 from ericl/summary and squashes the following commits: ccd54c3 [Eric Liang] second pass a5ca93b [Eric Liang] comments 2772111 [Eric Liang] clean up 70483ef [Eric Liang] fix test 7c247d4 [Eric Liang] Merge branch 'master' into summary 3c55024 [Eric Liang] working 8c539aa [Eric Liang] first pass --- R/pkg/NAMESPACE | 3 ++- R/pkg/R/mllib.R | 26 ++++++++++++++++++ R/pkg/inst/tests/test_mllib.R | 11 ++++++++ .../spark/ml/feature/OneHotEncoder.scala | 12 ++++----- .../apache/spark/ml/feature/RFormula.scala | 12 ++++++++- .../apache/spark/ml/r/SparkRWrappers.scala | 27 +++++++++++++++++-- .../ml/regression/LinearRegression.scala | 8 ++++-- .../spark/ml/feature/OneHotEncoderSuite.scala | 8 +++--- .../spark/ml/feature/RFormulaSuite.scala | 18 +++++++++++++ 9 files changed, 108 insertions(+), 17 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 7f7a8a2e4de24..a329e14f25aeb 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -12,7 +12,8 @@ export("print.jobj") # MLlib integration exportMethods("glm", - "predict") + "predict", + "summary") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 6a8bacaa552c6..efddcc1d8d71c 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -71,3 +71,29 @@ setMethod("predict", signature(object = "PipelineModel"), function(object, newData) { return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) }) + +#' Get the summary of a model +#' +#' Returns the summary of a model produced by glm(), similarly to R's summary(). +#' +#' @param model A fitted MLlib model +#' @return a list with a 'coefficient' component, which is the matrix of coefficients. See +#' summary.glm for more information. 
+#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#'} +setMethod("summary", signature(object = "PipelineModel"), + function(object) { + features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelFeatures", object@model) + weights <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelWeights", object@model) + coefficients <- as.matrix(unlist(weights)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 3bef69324770a..f272de78ad4a6 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -48,3 +48,14 @@ test_that("dot minus and intercept vs native glm", { rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) + +test_that("summary coefficients match with native glm", { + training <- createDataFrame(sqlContext, iris) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + coefs <- as.vector(stats$coefficients) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) + expect_true(all(abs(rCoefs - coefs) < 1e-6)) + expect_true(all( + as.character(stats$features) == + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 3825942795645..9c60d4084ec46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) @@ -79,17 +78,17 @@ class OneHotEncoder(override val uid: String) extends Transformer val outputAttrNames: Option[Array[String]] = inputAttr match { case nominal: NominalAttribute => if (nominal.values.isDefined) { - nominal.values.map(_.map(v => inputColName + is + v)) + nominal.values } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(i => inputColName + is + i)) + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) } else { None } case binary: BinaryAttribute => if (binary.values.isDefined) { - binary.values.map(_.map(v => inputColName + is + v)) + binary.values } else { - Some(Array.tabulate(2)(i => inputColName + is + i)) + Some(Array.tabulate(2)(_.toString)) } case _: NumericAttribute => throw new RuntimeException( @@ -123,7 +122,6 @@ class OneHotEncoder(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { // schema transformation - val is = "_is_" val inputColName = $(inputCol) val outputColName = $(outputCol) val shouldDropLast = $(dropLast) @@ -142,7 +140,7 @@ class OneHotEncoder(override val uid: String) extends Transformer math.max(m0, m1) } ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(i => inputColName + is + i) + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames val outputAttrs: 
Array[Attribute] = filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 0b428d278d908..d1726917e4517 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers @@ -91,11 +92,20 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() + val takenNames = mutable.Set(dataset.columns: _*) val encodedTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid - val encodedCol = term + "_onehot_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) tempColumns += indexCol diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 9f70592ccad7e..f5a022c31ed90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame @@ -44,4 +45,26 @@ private[r] object SparkRWrappers { val pipeline = new Pipeline().setStages(Array(formula, estimator)) pipeline.fit(df) } + + def getModelWeights(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => + Array(m.intercept) ++ m.weights.toArray + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No weights available for LogisticRegressionModel") // SPARK-9492 + } + } + + def getModelFeatures(model: PipelineModel): Array[String] = { + model.stages.last match { + case m: LinearRegressionModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + case _: LogisticRegressionModel => + throw new UnsupportedOperationException( + "No features names available for LogisticRegressionModel") // SPARK-9492 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 89718e0f3e15a..3b85ba001b128 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -36,6 +36,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.StructField import org.apache.spark.storage.StorageLevel import org.apache.spark.util.StatCounter @@ -146,9 +147,10 @@ class LinearRegression(override val uid: String) val model = new LinearRegressionModel(uid, weights, intercept) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } @@ -221,9 +223,10 @@ class LinearRegression(override val uid: String) val model = copyValues(new LinearRegressionModel(uid, weights, intercept)) val trainingSummary = new LinearRegressionTrainingSummary( - model.transform(dataset).select($(predictionCol), $(labelCol)), + model.transform(dataset), $(predictionCol), $(labelCol), + $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -300,6 +303,7 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + val featuresCol: String, val objectiveHistory: Array[Double]) extends LinearRegressionSummary(predictions, predictionCol, labelCol) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 65846a846b7b4..321eeb843941c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -86,8 +86,8 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("size_is_small").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("size_is_medium").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) } test("input column without ML attribute") { @@ -98,7 +98,7 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { val output = encoder.transform(df) val group = AttributeGroup.fromStructField(output.schema("encoded")) assert(group.size === 2) - assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("index_is_0").withIndex(0)) - assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("index_is_1").withIndex(1)) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 8148c553e9051..6aed3243afce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import 
org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -105,4 +106,21 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(result.schema.toString == resultSchema.toString) assert(result.collect() === expected.collect()) } + + test("attribute generation") { + val formula = new RFormula().setFormula("id ~ a + b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array( + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), + new NumericAttribute(Some("b"), Some(3)))) + assert(attrs === expectedAttrs) + } } From 157840d1b14502a4f25cff53633c927998c6ada1 Mon Sep 17 00:00:00 2001 From: Hossein Date: Thu, 30 Jul 2015 16:16:17 -0700 Subject: [PATCH 181/219] [SPARK-8742] [SPARKR] Improve SparkR error messages for DataFrame API This patch improves SparkR error message reporting, especially with DataFrame API. When there is a user error (e.g., malformed SQL query), the message of the cause is sent back through the RPC and the R client reads it and returns it back to user. cc shivaram Author: Hossein Closes #7742 from falaki/SPARK-8742 and squashes the following commits: 4f643c9 [Hossein] Not logging exceptions in RBackendHandler 4a8005c [Hossein] Returning stack track of causing exception from RBackendHandler 5cf17f0 [Hossein] Adding unit test for error messages from SQLContext 2af75d5 [Hossein] Reading error message in case of failure and stoping with that message f479c99 [Hossein] Wrting exception cause message in JVM --- R/pkg/R/backend.R | 4 +++- R/pkg/inst/tests/test_sparkSQL.R | 5 +++++ .../scala/org/apache/spark/api/r/RBackendHandler.scala | 10 ++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R index 2fb6fae55f28c..49162838b8d1a 100644 --- a/R/pkg/R/backend.R +++ b/R/pkg/R/backend.R @@ -110,6 +110,8 @@ invokeJava <- function(isStatic, objId, methodName, ...) 
{ # TODO: check the status code to output error information returnStatus <- readInt(conn) - stopifnot(returnStatus == 0) + if (returnStatus != 0) { + stop(readString(conn)) + } readObject(conn) } diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index d5db97248c770..61c8a7ec7d837 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1002,6 +1002,11 @@ test_that("crosstab() on a DataFrame", { expect_identical(expected, ordered) }) +test_that("SQL error message is returned from JVM", { + retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) + expect_equal(grepl("Table Not Found: blah", retError), TRUE) +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index a5de10fe89c42..14dac4ed28ce3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -69,8 +69,11 @@ private[r] class RBackendHandler(server: RBackend) case e: Exception => logError(s"Removing $objId failed", e) writeInt(dos, -1) + writeString(dos, s"Removing $objId failed: ${e.getMessage}") } - case _ => dos.writeInt(-1) + case _ => + dos.writeInt(-1) + writeString(dos, s"Error: unknown method $methodName") } } else { handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos) @@ -146,8 +149,11 @@ private[r] class RBackendHandler(server: RBackend) } } catch { case e: Exception => - logError(s"$methodName on $objId failed", e) + logError(s"$methodName on $objId failed") writeInt(dos, -1) + // Writing the error message of the cause for the exception. This will be returned + // to user in the R process. + writeString(dos, Utils.exceptionString(e.getCause)) } } From 04c8409107710fc9a625ee513d68c149745539f3 Mon Sep 17 00:00:00 2001 From: Calvin Jia Date: Thu, 30 Jul 2015 16:32:40 -0700 Subject: [PATCH 182/219] [SPARK-9199] [CORE] Update Tachyon dependency from 0.6.4 -> 0.7.0 No new dependencies are added. The exclusion changes are due to the change in tachyon-client 0.7.0's project structure. There is no client side API change in Tachyon 0.7.0 so no code changes are required. 
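For downstream builds that depend on tachyon-client directly, the same exclusions can be expressed in sbt coordinates. A minimal sketch, assuming sbt's standard `ExclusionRule` API; Spark's own build carries these exclusions in the `core/pom.xml` change below:

```scala
// Hypothetical sbt equivalent of the Maven exclusions in the diff below.
// tachyon-client 0.7.0 split its under-filesystem backends into separate
// modules, which is why the glusterfs and s3 artifacts are now excluded.
libraryDependencies += "org.tachyonproject" % "tachyon-client" % "0.7.0" excludeAll(
  ExclusionRule(organization = "org.apache.hadoop", name = "hadoop-client"),
  ExclusionRule(organization = "org.apache.curator", name = "curator-recipes"),
  ExclusionRule(organization = "org.tachyonproject", name = "tachyon-underfs-glusterfs"),
  ExclusionRule(organization = "org.tachyonproject", name = "tachyon-underfs-s3")
)
```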
Author: Calvin Jia Closes #7577 from calvinjia/SPARK-9199 and squashes the following commits: 4e81e40 [Calvin Jia] Update Tachyon dependency from 0.6.4 -> 0.7.0 --- core/pom.xml | 34 +++++----------------------------- make-distribution.sh | 2 +- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 6fa87ec6a24af..202678779150b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -286,7 +286,7 @@ org.tachyonproject tachyon-client - 0.6.4 + 0.7.0 org.apache.hadoop @@ -297,36 +297,12 @@ curator-recipes - org.eclipse.jetty - jetty-jsp + org.tachyonproject + tachyon-underfs-glusterfs - org.eclipse.jetty - jetty-webapp - - - org.eclipse.jetty - jetty-server - - - org.eclipse.jetty - jetty-servlet - - - junit - junit - - - org.powermock - powermock-module-junit4 - - - org.powermock - powermock-api-mockito - - - org.apache.curator - curator-test + org.tachyonproject + tachyon-underfs-s3 diff --git a/make-distribution.sh b/make-distribution.sh index cac7032bb2e87..4789b0e09cc8a 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -33,7 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false -TACHYON_VERSION="0.6.4" +TACHYON_VERSION="0.7.0" TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" From 1afdeb7b458f86e2641f062fb9ddc00e9c5c7531 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 30 Jul 2015 16:44:02 -0700 Subject: [PATCH 183/219] [STREAMING] [TEST] [HOTFIX] Fixed Kinesis test to not throw weird errors when Kinesis tests are enabled without AWS keys If Kinesis tests are enabled by env ENABLE_KINESIS_TESTS = 1 but no AWS credentials are found, the desired behavior is the fail the test using with ``` Exception encountered when attempting to run a suite with class name: org.apache.spark.streaming.kinesis.KinesisBackedBlockRDDSuite *** ABORTED *** (3 seconds, 5 milliseconds) [info] java.lang.Exception: Kinesis tests enabled, but could get not AWS credentials ``` Instead KinesisStreamSuite fails with ``` [info] - basic operation *** FAILED *** (3 seconds, 35 milliseconds) [info] java.lang.IllegalArgumentException: requirement failed: Stream not yet created, call createStream() to create one [info] at scala.Predef$.require(Predef.scala:233) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.streamName(KinesisTestUtils.scala:77) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils$$anonfun$deleteStream$1.apply(KinesisTestUtils.scala:150) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils$$anonfun$deleteStream$1.apply(KinesisTestUtils.scala:150) [info] at org.apache.spark.Logging$class.logWarning(Logging.scala:71) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.logWarning(KinesisTestUtils.scala:39) [info] at org.apache.spark.streaming.kinesis.KinesisTestUtils.deleteStream(KinesisTestUtils.scala:150) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply$mcV$sp(KinesisStreamSuite.scala:111) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply(KinesisStreamSuite.scala:86) [info] at org.apache.spark.streaming.kinesis.KinesisStreamSuite$$anonfun$3.apply(KinesisStreamSuite.scala:86) ``` This is because attempting to delete a non-existent Kinesis stream throws uncaught exception. This PR fixes it. 
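The core of the fix is to delete the stream only when the test itself created it, and to treat deletion failures as non-fatal. A minimal sketch of that guard, with the client and logging wiring simplified from KinesisTestUtils:

```scala
// Sketch of the guarded teardown; kinesisClient, streamName and logWarning
// stand in for the corresponding members of KinesisTestUtils.
@volatile private var streamCreated = false

def deleteStream(): Unit = {
  try {
    // Deleting a stream that createStream() never made throws an uncaught
    // exception, which is what produced the confusing failure above.
    if (streamCreated) {
      kinesisClient.deleteStream(streamName)
    }
  } catch {
    case e: Exception =>
      // Best-effort cleanup: warn and move on instead of failing the suite.
      logWarning(s"Could not delete stream $streamName", e)
  }
}
```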
Author: Tathagata Das Closes #7809 from tdas/kinesis-test-hotfix and squashes the following commits: 7c372e6 [Tathagata Das] Fixed test --- .../streaming/kinesis/KinesisTestUtils.scala | 27 ++++++++++--------- .../kinesis/KinesisStreamSuite.scala | 4 +-- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0ff1b7ed0fd90..ca39358b75cb6 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -53,6 +53,8 @@ private class KinesisTestUtils( @volatile private var streamCreated = false + + @volatile private var _streamName: String = _ private lazy val kinesisClient = { @@ -115,21 +117,9 @@ private class KinesisTestUtils( shardIdToSeqNumbers.toMap } - def describeStream(streamNameToDescribe: String = streamName): Option[StreamDescription] = { - try { - val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) - val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() - Some(desc) - } catch { - case rnfe: ResourceNotFoundException => - None - } - } - def deleteStream(): Unit = { try { - if (describeStream().nonEmpty) { - val deleteStreamRequest = new DeleteStreamRequest() + if (streamCreated) { kinesisClient.deleteStream(streamName) } } catch { @@ -149,6 +139,17 @@ private class KinesisTestUtils( } } + private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { + try { + val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) + val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + Some(desc) + } catch { + case rnfe: ResourceNotFoundException => + None + } + } + private def findNonExistentStreamName(): String = { var testStreamName: String = null do { diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index f9c952b9468bb..b88c9c6478d56 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -88,11 +88,11 @@ class KinesisStreamSuite extends KinesisFunSuite try { kinesisTestUtils.createStream() ssc = new StreamingContext(sc, Seconds(1)) - val aWSCredentials = KinesisTestUtils.getAWSCredentials() + val awsCredentials = KinesisTestUtils.getAWSCredentials() val stream = KinesisUtils.createStream(ssc, kinesisAppName, kinesisTestUtils.streamName, kinesisTestUtils.endpointUrl, kinesisTestUtils.regionName, InitialPositionInStream.LATEST, Seconds(10), StorageLevel.MEMORY_ONLY, - aWSCredentials.getAWSAccessKeyId, aWSCredentials.getAWSSecretKey) + awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => From ca71cc8c8b2d64b7756ae697c06876cd18b536dc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 16:57:38 -0700 Subject: [PATCH 184/219] [SPARK-9408] [PYSPARK] [MLLIB] Refactor linalg.py to /linalg This is based on MechCoder 's PR 
https://github.com/apache/spark/pull/7731. Hopefully it could pass tests. MechCoder I tried to make minimal changes. If this passes Jenkins, we can merge this one first and then try to move `__init__.py` to `local.py` in a separate PR. Closes #7731 Author: Xiangrui Meng Closes #7746 from mengxr/SPARK-9408 and squashes the following commits: 0e05a3b [Xiangrui Meng] merge master 1135551 [Xiangrui Meng] add a comment for str(...) c48cae0 [Xiangrui Meng] update tests 173a805 [Xiangrui Meng] move linalg.py to linalg/__init__.py --- dev/sparktestsupport/modules.py | 2 +- python/pyspark/mllib/{linalg.py => linalg/__init__.py} | 0 python/pyspark/sql/types.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename python/pyspark/mllib/{linalg.py => linalg/__init__.py} (100%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 030d982e99106..44600cb9523c1 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -323,7 +323,7 @@ def contains_file(self, filename): "pyspark.mllib.evaluation", "pyspark.mllib.feature", "pyspark.mllib.fpm", - "pyspark.mllib.linalg", + "pyspark.mllib.linalg.__init__", "pyspark.mllib.random", "pyspark.mllib.recommendation", "pyspark.mllib.regression", diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg/__init__.py similarity index 100% rename from python/pyspark/mllib/linalg.py rename to python/pyspark/mllib/linalg/__init__.py diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0976aea72c034..6f74b7162f7cc 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -648,7 +648,7 @@ def jsonValue(self): @classmethod def fromJson(cls, json): - pyUDT = str(json["pyClass"]) + pyUDT = str(json["pyClass"]) # convert unicode to str split = pyUDT.rfind(".") pyModule = pyUDT[:split] pyClass = pyUDT[split+1:] From df32669514afc0223ecdeca30fbfbe0b40baef3a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 30 Jul 2015 17:16:03 -0700 Subject: [PATCH 185/219] [SPARK-7157][SQL] add sampleBy to DataFrame This was previously committed but then reverted due to test failures (see #6769). 
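The new API in action from the Scala side — a short sketch whose column name and fractions mirror the test added to DataFrameStatSuite below:

```scala
import org.apache.spark.sql.functions.col

// Stratified sampling without replacement: keep roughly 10% of stratum 0
// and 20% of stratum 1; unlisted strata (here, key = 2) default to 0.0.
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), seed = 0L)
sampled.groupBy("key").count().orderBy("key").show()
```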
Author: Xiangrui Meng Closes #7755 from rxin/SPARK-7157 and squashes the following commits: fbf9044 [Xiangrui Meng] fix python test 542bd37 [Xiangrui Meng] update test 604fe6d [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 f051afd [Xiangrui Meng] use udf instead of building expression f4e9425 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 8fb990b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7157 103beb3 [Xiangrui Meng] add Java-friendly sampleBy 991f26f [Xiangrui Meng] fix seed 4a14834 [Xiangrui Meng] move sampleBy to stat 832f7cc [Xiangrui Meng] add sampleBy to DataFrame --- python/pyspark/sql/dataframe.py | 41 ++++++++++++++++++ .../spark/sql/DataFrameStatFunctions.scala | 42 +++++++++++++++++++ .../apache/spark/sql/JavaDataFrameSuite.java | 9 ++++ .../apache/spark/sql/DataFrameStatSuite.scala | 12 +++++- 4 files changed, 102 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d76e051bd73a1..0f3480c239187 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -441,6 +441,42 @@ def sample(self, withReplacement, fraction, seed=None): rdd = self._jdf.sample(withReplacement, fraction, long(seed)) return DataFrame(rdd, self.sql_ctx) + @since(1.5) + def sampleBy(self, col, fractions, seed=None): + """ + Returns a stratified sample without replacement based on the + fraction given on each stratum. + + :param col: column that defines strata + :param fractions: + sampling fraction for each stratum. If a stratum is not + specified, we treat its fraction as zero. + :param seed: random seed + :return: a new DataFrame that represents the stratified sample + + >>> from pyspark.sql.functions import col + >>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key")) + >>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0) + >>> sampled.groupBy("key").count().orderBy("key").show() + +---+-----+ + |key|count| + +---+-----+ + | 0| 3| + | 1| 8| + +---+-----+ + + """ + if not isinstance(col, str): + raise ValueError("col must be a string, but got %r" % type(col)) + if not isinstance(fractions, dict): + raise ValueError("fractions must be a dict but got %r" % type(fractions)) + for k, v in fractions.items(): + if not isinstance(k, (float, int, long, basestring)): + raise ValueError("key must be float, int, long, or string, but got %r" % type(k)) + fractions[k] = float(v) + seed = seed if seed is not None else random.randint(0, sys.maxsize) + return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx) + @since(1.4) def randomSplit(self, weights, seed=None): """Randomly splits this :class:`DataFrame` with the provided weights. 
@@ -1314,6 +1350,11 @@ def freqItems(self, cols, support=None): freqItems.__doc__ = DataFrame.freqItems.__doc__ + def sampleBy(self, col, fractions, seed=None): + return self.df.sampleBy(col, fractions, seed) + + sampleBy.__doc__ = DataFrame.sampleBy.__doc__ + def _test(): import doctest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 4ec58082e7aef..2e68e358f2f1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import java.{util => ju, lang => jl} + +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.sql.execution.stat._ @@ -166,4 +170,42 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def freqItems(cols: Seq[String]): DataFrame = { FrequentItems.singlePassFreqItems(df, cols, 0.01) } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. + * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DataFrame = { + require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), + s"Fractions must be in [0, 1], but got $fractions.") + import org.apache.spark.sql.functions.{rand, udf} + val c = Column(col) + val r = rand(seed) + val f = udf { (stratum: Any, x: Double) => + x < fractions.getOrElse(stratum.asInstanceOf[T], 0.0) + } + df.filter(f(c, r)) + } + + /** + * Returns a stratified sample without replacement based on the fraction given on each stratum. + * @param col column that defines strata + * @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat + * its fraction as zero. 
+ * @param seed random seed + * @tparam T stratum type + * @return a new [[DataFrame]] that represents the stratified sample + * + * @since 1.5.0 + */ + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = { + sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 9e61d06f4036e..2c669bb59a0b5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -226,4 +226,13 @@ public void testCovariance() { Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1e-6); } + + @Test + public void testSampleBy() { + DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key")); + DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); + Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); + Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)}; + Assert.assertArrayEquals(expected, actual); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 7ba4ba73e0cc9..07a675e64f527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -21,9 +21,9 @@ import java.util.Random import org.scalatest.Matchers._ -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.functions.col -class DataFrameStatSuite extends SparkFunSuite { +class DataFrameStatSuite extends QueryTest { private val sqlCtx = org.apache.spark.sql.test.TestSQLContext import sqlCtx.implicits._ @@ -130,4 +130,12 @@ class DataFrameStatSuite extends SparkFunSuite { val items2 = singleColResults.collect().head items2.getSeq[Double](0) should contain (-1.0) } + + test("sampleBy") { + val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key")) + val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) + checkAnswer( + sampled.groupBy("key").count().orderBy("key"), + Seq(Row(0, 5), Row(1, 8))) + } } From e7a0976e991f75a7bda99509e2b040daab965ae6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 17:17:27 -0700 Subject: [PATCH 186/219] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. Author: Reynold Xin Closes #7803 from rxin/SPARK-9458 and squashes the following commits: 5b032dc [Reynold Xin] Fix string. b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. 
--- .../unsafe/sort/PrefixComparators.java | 49 ++++++++------ .../unsafe/sort/PrefixComparatorsSuite.scala | 22 ++----- .../execution/UnsafeExternalRowSorter.java | 27 ++++---- .../sql/catalyst/expressions/SortOrder.scala | 44 ++++++++++++- .../spark/sql/execution/SortPrefixUtils.scala | 64 +++---------------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../sql/execution/joins/HashedRelation.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 64 ++++++++----------- .../execution/RowFormatConvertersSuite.scala | 11 ++-- ...ortSuite.scala => TungstenSortSuite.scala} | 10 +-- 10 files changed, 138 insertions(+), 161 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/execution/{UnsafeExternalSortSuite.scala => TungstenSortSuite.scala} (87%) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 600aff7d15d8a..4d7e5b3dfba6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -28,9 +28,11 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); public static final class StringPrefixComparator extends PrefixComparator { @Override @@ -38,50 +40,55 @@ public int compare(long aPrefix, long bPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } - public long computePrefix(UTF8String value) { + public static long computePrefix(UTF8String value) { return value == null ? 0L : value.getPrefix(); } } - /** - * Prefix comparator for all integral types (boolean, byte, short, int, long). - */ - public static final class IntegralPrefixComparator extends PrefixComparator { + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + } - public final long NULL_PREFIX = Long.MIN_VALUE; + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; + } } - public static final class FloatPrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparatorDesc extends PrefixComparator { @Override - public int compare(long aPrefix, long bPrefix) { + public int compare(long bPrefix, long aPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(double value) { + public static long computePrefix(double value) { return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index cf53a8ad21c60..26a2e96edaaa2 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -29,8 +29,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { def testPrefixComparison(s1: String, s2: String): Unit = { val utf8string1 = UTF8String.fromString(s1) val utf8string2 = UTF8String.fromString(s2) - val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1) - val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) val cmp = UnsignedBytes.lexicographicalComparator().compare( @@ -55,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) - val 
nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) - val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557140..68c49feae938e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; @@ -62,7 +61,6 @@ public UnsafeExternalRowSorter( PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -88,13 +86,12 @@ void setTestSpillFrequency(int frequency) { } @VisibleForTesting - void insertRow(InternalRow row) throws IOException { - UnsafeRow unsafeRow = unsafeProjection.apply(row); + void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - unsafeRow.getBaseObject(), - unsafeRow.getBaseOffset(), - unsafeRow.getSizeInBytes(), + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), prefix ); numRowsInserted++; @@ -113,7 +110,7 @@ private void cleanupResources() { } @VisibleForTesting - Iterator sort() throws IOException { + Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -121,7 +118,7 @@ Iterator sort() throws IOException { // here in order to prevent memory leaks. 
cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); @@ -132,7 +129,7 @@ public boolean hasNext() { } @Override - public InternalRow next() { + public UnsafeRow next() { try { sortedIterator.loadNext(); row.pointTo( @@ -164,11 +161,11 @@ public InternalRow next() { } - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3f436c0eb893c..9fe877f10fa08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection case object Ascending extends SortDirection @@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection) override def nullable: Boolean = child.nullable override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + + def isAscending: Boolean = direction == Ascending +} + +/** + * An expression to generate a 64-bit long prefix used in sorting. + */ +case class SortPrefix(child: SortOrder) extends UnaryExpression { + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childCode = child.child.gen(ctx) + val input = childCode.primitive + val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { + case BooleanType => + (Long.MinValue, s"$input ? 
1L : 0L") + case _: IntegralType => + (Long.MinValue, s"(long) $input") + case FloatType | DoubleType => + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case _ => (0L, "0L") + } + + childCode.code + + s""" + |long ${ev.primitive} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.primitive} = $prefixCode; + |} + """.stripMargin + } + + override def dataType: DataType = LongType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6101..a2145b185ce90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -37,61 +35,15 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => PrefixComparators.STRING - case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case StringType if sortOrder.isAscending => PrefixComparators.STRING + case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC + case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending => + PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending => + PrefixComparators.LONG_DESC + case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE + case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC case _ => NoOpPrefixComparator } } - - def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } - case BooleanType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 - else 0 - } - case ByteType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] - } - case ShortType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] - } - case IntegerType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] - } - case LongType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } - case FloatType => 
(row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - } - case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) - } - case _ => (row: InternalRow) => 0L - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 52a9b02d373c7..03d24a88d4ecd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,8 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 26dbc911e9521..f88a45f48aee9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -229,7 +229,7 @@ private[joins] final class UnsafeHashedRelation( // write all the values as single byte array var totalSize = 0L var i = 0 - while (i < values.size) { + while (i < values.length) { totalSize += values(i).getSizeInBytes + 4 + 4 i += 1 } @@ -240,7 +240,7 @@ private[joins] final class UnsafeHashedRelation( out.writeInt(totalSize.toInt) out.write(key.getBytes) i = 0 - while (i < values.size) { + while (i < values.length) { // [num of fields] [num of bytes] [row bytes] // write the integer in native order, so they can be read by UNSAFE.getInt() if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index f82208868c3e3..6d903ab23c57f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator 
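A side note on the descending comparators introduced above: they reverse order by swapping their two arguments rather than negating the ascending result, and the -1 * comp.compare(p1, p2) hack they replace (visible in the TungstenSort hunk just below) is safe only because these comparators return exactly -1, 0, or 1. A comparator contract in general permits any negative int, including Int.MinValue, whose arithmetic negation overflows back to itself. A plain-Scala illustration with hypothetical names, not Spark code:

    // Sketch only: reversing a comparator by argument swap vs. negation.
    object ReverseComparatorSketch {
      // A contract-valid comparator may return any negative int, including
      // Int.MinValue, whose negation overflows back to Int.MinValue.
      def asc(a: Long, b: Long): Int =
        if (a < b) Int.MinValue else if (a > b) 1 else 0

      def descByNegation(a: Long, b: Long): Int = -1 * asc(a, b) // overflows on MinValue
      def descBySwap(a: Long, b: Long): Int = asc(b, a)          // always safe

      def main(args: Array[String]): Unit = {
        println(descByNegation(1L, 2L)) // -2147483648: still "less than" -- wrong order
        println(descBySwap(1L, 2L))     // 1: "greater than" under descending -- correct
      }
    }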
//////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -97,59 +95,53 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class UnsafeExternalSort( +case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) extends UnaryNode { - private[this] val schema: StructType = child.schema + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + child.execute().mapPartitions({ iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) } } + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + }, preservesPartitioning = true) } - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true } -@DeveloperApi -object UnsafeExternalSort { +object TungstenSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f755918c1..707cd9c6d939b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { @@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { @@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("filter can process unsafe rows") { - val plan = Filter(IsNull(null), outputsUnsafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) + assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { - val plan = Filter(IsNull(null), outputsSafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala similarity index 87% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala index 138636b0c65b8..450963547c798 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { +class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { override def beforeAll(): Unit = { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) @@ -39,7 +39,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), sortAnswers = false ) @@ -50,7 +50,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll 
{ val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) @@ -70,11 +70,11 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) + assert(TungstenSort.supportsSchema(inputDf.schema)) checkThatPlansAgree( inputDf, plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), Sort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) From 0b1a464b6e061580a75b99a91b042069d76bbbfd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 30 Jul 2015 17:18:32 -0700 Subject: [PATCH 187/219] [SPARK-9425] [SQL] support DecimalType in UnsafeRow This PR brings the support of DecimalType in UnsafeRow, for precision <= 18, it's settable, otherwise it's not settable. Author: Davies Liu Closes #7758 from davies/unsafe_decimal and squashes the following commits: 478b1ba [Davies Liu] address comments 536314c [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal 7c2e77a [Davies Liu] fix JoinedRow 76d6fa4 [Davies Liu] fix tests 99d3151 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_decimal d49c6ae [Davies Liu] support DecimalType in UnsafeRow --- .../expressions/SpecializedGetters.java | 2 +- .../UnsafeFixedWidthAggregationMap.java | 22 ++-- .../sql/catalyst/expressions/UnsafeRow.java | 53 +++++--- .../expressions/UnsafeRowWriters.java | 42 +++++++ .../sql/catalyst/CatalystTypeConverters.scala | 9 +- .../spark/sql/catalyst/InternalRow.scala | 4 +- .../sql/catalyst/expressions/Projection.scala | 7 +- .../expressions/codegen/CodeGenerator.scala | 9 +- .../codegen/GenerateUnsafeProjection.scala | 115 ++++++++++-------- .../spark/sql/catalyst/expressions/rows.scala | 3 +- .../org/apache/spark/sql/types/Decimal.scala | 6 +- .../spark/sql/types/GenericArrayData.scala | 2 +- .../sql/catalyst/expressions/CastSuite.scala | 5 +- .../expressions/DateExpressionsSuite.scala | 2 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 8 +- .../expressions/UnsafeRowConverterSuite.scala | 17 +-- .../spark/sql/columnar/ColumnBuilder.scala | 2 +- .../spark/sql/columnar/ColumnStats.scala | 4 +- .../spark/sql/columnar/ColumnType.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../sql/execution/SparkSqlSerializer2.scala | 2 +- .../sql/parquet/ParquetTableSupport.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 40 +++++- 23 files changed, 237 insertions(+), 125 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index f7cea13688876..e3d3ba7a9ccc0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -41,7 +41,7 @@ public interface SpecializedGetters { double 
getDouble(int ordinal); - Decimal getDecimal(int ordinal); + Decimal getDecimal(int ordinal, int precision, int scale); UTF8String getUTF8String(int ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 03f4c3ed8e6bb..f3b462778dc10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -20,6 +20,8 @@ import java.util.Iterator; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.PlatformDependent; @@ -61,26 +63,18 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given * schema, false otherwise. */ public static boolean supportsAggregationBufferSchema(StructType schema) { for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { + if (field.dataType() instanceof DecimalType) { + DecimalType dt = (DecimalType) field.dataType(); + if (dt.precision() > Decimal.MAX_LONG_DIGITS()) { + return false; + } + } else if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { return false; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 6d684bac37573..e7088edced1a1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.io.OutputStream; +import java.math.BigDecimal; +import java.math.BigInteger; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; @@ -65,12 +67,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { */ public static final Set settableFieldTypes; - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException). 
- */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType + // DecimalType(precision <= 18) is settable static { settableFieldTypes = Collections.unmodifiableSet( new HashSet<>( @@ -86,16 +83,6 @@ public static int calculateBitSetWidthInBytes(int numFields) { DateType, TimestampType }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet<>( - Arrays.asList(new DataType[]{ - StringType, - BinaryType, - CalendarIntervalType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); } ////////////////////////////////////////////////////////////////////////////// @@ -232,6 +219,21 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + assertIndexIsValid(ordinal); + if (value == null) { + setNullAt(ordinal); + } else { + if (precision <= Decimal.MAX_LONG_DIGITS()) { + setLong(ordinal, value.toUnscaledLong()); + } else { + // TODO(davies): support update decimal (hold a bounded space even it's null) + throw new UnsupportedOperationException(); + } + } + } + @Override public Object get(int ordinal) { throw new UnsupportedOperationException(); @@ -256,7 +258,8 @@ public Object get(int ordinal, DataType dataType) { } else if (dataType instanceof DoubleType) { return getDouble(ordinal); } else if (dataType instanceof DecimalType) { - return getDecimal(ordinal); + DecimalType dt = (DecimalType) dataType; + return getDecimal(ordinal, dt.precision(), dt.scale()); } else if (dataType instanceof DateType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { @@ -322,6 +325,22 @@ public double getDouble(int ordinal) { return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + assertIndexIsValid(ordinal); + if (isNullAt(ordinal)) { + return null; + } + if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(ordinal), precision, scale); + } else { + byte[] bytes = getBinary(ordinal); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + } + } + @Override public UTF8String getUTF8String(int ordinal) { assertIndexIsValid(ordinal); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index c3259e21c4a78..f43a285cd6cad 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -30,6 +31,47 @@ */ public class UnsafeRowWriters { + /** Writer for Decimal with precision under 18. 
*/ + public static class CompactDecimalWriter { + + public static int getSize(Decimal input) { + return 0; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + target.setLong(ordinal, input.toUnscaledLong()); + return 0; + } + } + + /** Writer for Decimal with precision larger than 18. */ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + // bounded size + return 16; + } + + public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { + final long offset = target.getBaseOffset() + cursor; + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset, 0L); + PlatformDependent.UNSAFE.putLong(target.getBaseObject(), offset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, + target.getBaseObject(), offset, numBytes); + + // Set the fixed length portion. + target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + return 16; + } + } + /** Writer for UTF8String. */ public static class UTF8StringWriter { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 22452c0f201ef..7ca20fe97fbef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -68,7 +68,7 @@ object CatalystTypeConverters { case StringType => StringConverter case DateType => DateConverter case TimestampType => TimestampConverter - case dt: DecimalType => BigDecimalConverter + case dt: DecimalType => new DecimalConverter(dt) case BooleanType => BooleanConverter case ByteType => ByteConverter case ShortType => ShortConverter @@ -306,7 +306,8 @@ object CatalystTypeConverters { DateTimeUtils.toJavaTimestamp(row.getLong(column)) } - private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { + private class DecimalConverter(dataType: DecimalType) + extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { override def toCatalystImpl(scalaValue: Any): Decimal = scalaValue match { case d: BigDecimal => Decimal(d) case d: JavaBigDecimal => Decimal(d) @@ -314,9 +315,11 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.getDecimal(column).toJavaBigDecimal + row.getDecimal(column, dataType.precision, dataType.scale).toJavaBigDecimal } + private object BigDecimalConverter extends DecimalConverter(DecimalType.SYSTEM_DEFAULT) + private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { final override def toScala(catalystValue: Any): Any = catalystValue final override def toCatalystImpl(scalaValue: T): Any = scalaValue diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 486ba036548c8..b19bf4386b0ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -58,8 +58,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - override def getDecimal(ordinal: Int): Decimal = - getAs[Decimal](ordinal, DecimalType.SYSTEM_DEFAULT) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + getAs[Decimal](ordinal, DecimalType(precision, scale)) override def getInterval(ordinal: Int): CalendarInterval = getAs[CalendarInterval](ordinal, CalendarIntervalType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index b3beb7e28f208..7c7664e4c1a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types.{Decimal, StructType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -225,6 +225,11 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) row1.getDecimal(i, precision, scale) + else row2.getDecimal(i - row1.numFields, precision, scale) + } + override def getStruct(i: Int, numFields: Int): InternalRow = { if (i < row1.numFields) { row1.getStruct(i, numFields) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index c39e0df6fae2a..60e2863f7bbb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -106,6 +106,7 @@ class CodeGenContext { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => s"$getter.get${primitiveTypeName(jt)}($ordinal)" + case t: DecimalType => s"$getter.getDecimal($ordinal, ${t.precision}, ${t.scale})" case StringType => s"$getter.getUTF8String($ordinal)" case BinaryType => s"$getter.getBinary($ordinal)" case CalendarIntervalType => s"$getter.getInterval($ordinal)" @@ -120,10 +121,10 @@ class CodeGenContext { */ def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) - if (isPrimitiveType(jt)) { - s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" - } else { - s"$row.update($ordinal, $value)" + dataType match { + case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" + case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" + case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index a662357fb6cf9..1d223986d9441 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -35,6 +35,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName + private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName + private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { @@ -42,9 +44,64 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true + case t: DecimalType => true case _ => false } + def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s" + (${ev.isNull} ? 0 : $DecimalWriter.getSize(${ev.primitive}))" + case StringType => + s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" + case BinaryType => + s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" + case CalendarIntervalType => + s" + (${ev.isNull} ? 0 : 16)" + case _: StructType => + s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))" + case _ => "" + } + + def genFieldWriter( + ctx: CodeGenContext, + fieldType: DataType, + ev: GeneratedExpressionCode, + primitive: String, + index: Int, + cursor: String): String = fieldType match { + case _ if ctx.isPrimitiveType(fieldType) => + s"${ctx.setColumn(primitive, fieldType, index, ev.primitive)}" + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $CompactDecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => + s""" + // make sure Decimal object has the same scale as DecimalType + if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { + $cursor += $DecimalWriter.write($primitive, $index, $cursor, ${ev.primitive}); + } else { + $primitive.setNullAt($index); + } + """ + case StringType => + s"$cursor += $StringWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case BinaryType => + s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case CalendarIntervalType => + s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case t: StructType => + s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case NullType => "" + case _ => + throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") + } + /** * Generates the code to create an [[UnsafeRow]] object based on the input expressions. 
* @param ctx context for code generation @@ -69,36 +126,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val allExprs = exprs.map(_.code).mkString("\n") val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { case (e, i) => - e.dataType match { - case StringType => - s" + (${exprs(i).isNull} ? 0 : $StringWriter.getSize(${exprs(i).primitive}))" - case BinaryType => - s" + (${exprs(i).isNull} ? 0 : $BinaryWriter.getSize(${exprs(i).primitive}))" - case CalendarIntervalType => - s" + (${exprs(i).isNull} ? 0 : 16)" - case _: StructType => - s" + (${exprs(i).isNull} ? 0 : $StructWriter.getSize(${exprs(i).primitive}))" - case _ => "" - } + val additionalSize = expressions.zipWithIndex.map { + case (e, i) => genAdditionalSize(e.dataType, exprs(i)) }.mkString("") val writers = expressions.zipWithIndex.map { case (e, i) => - val update = e.dataType match { - case dt if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(ret, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursor += $StringWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case t: StructType => - s"$cursor += $StructWriter.write($ret, $i, $cursor, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: ${e.dataType}") - } + val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) s"""if (${exprs(i).isNull}) { $ret.setNullAt($i); } else { @@ -168,35 +201,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => - dt match { - case StringType => - s" + (${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive}))" - case BinaryType => - s" + (${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive}))" - case CalendarIntervalType => - s" + (${ev.isNull} ? 0 : 16)" - case _: StructType => - s" + (${ev.isNull} ? 
0 : $StructWriter.getSize(${ev.primitive}))" - case _ => "" - } + genAdditionalSize(dt, ev) }.mkString("") val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => - val update = dt match { - case _ if ctx.isPrimitiveType(dt) => - s"${ctx.setColumn(primitive, dt, i, exprs(i).primitive)}" - case StringType => - s"$cursor += $StringWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case t: StructType => - s"$cursor += $StructWriter.write($primitive, $i, $cursor, ${exprs(i).primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: $dt") - } + val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) s""" if (${exprs(i).isNull}) { $primitive.setNullAt($i); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index b7c4ece4a16fe..df6ea586c87ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, StructType, AtomicType} +import org.apache.spark.sql.types.{Decimal, DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String /** @@ -39,6 +39,7 @@ abstract class MutableRow extends InternalRow { def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } def setString(i: Int, value: String): Unit = { update(i, UTF8String.fromString(value)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc292..c0155eeb450a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -188,6 +188,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { * @return true if successful, false if overflow would occur */ def changePrecision(precision: Int, scale: Int): Boolean = { + // fast path for UnsafeProjection + if (precision == this.precision && scale == this.scale) { + return true + } // First, update our longVal if we can, or transfer over to using a BigDecimal if (decimalVal.eq(null)) { if (scale < _scale) { @@ -224,7 +228,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { decimalVal = newVal } else { // We're still using Longs, but we should check whether we match the new precision - val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) if (longVal <= -p || longVal >= p) { // Note that we shouldn't have been able to fix this by switching to BigDecimal return false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index 7992ba947c069..35ace673fb3da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -43,7 +43,7 @@ class GenericArrayData(array: Array[Any]) extends ArrayData { override def getDouble(ordinal: Int): Double = getAs(ordinal) - override def getDecimal(ordinal: Int): Decimal = getAs(ordinal) + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 4f35b653d73c0..1ad70733eae03 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -242,10 +242,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123)) checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123)) - checkEvaluation(cast(123L, DecimalType(3, 1)), Decimal(123.0)) + checkEvaluation(cast(123L, DecimalType(3, 1)), null) - // TODO: Fix the following bug and re-enable it. - // checkEvaluation(cast(123L, DecimalType(2, 0)), null) + checkEvaluation(cast(123L, DecimalType(2, 0)), null) } test("cast from boolean") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index fd1d6c1d25497..887e43621a941 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Calendar diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 6a907290f2dbe..c6b4c729de2f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -55,13 +55,13 @@ class UnsafeFixedWidthAggregationMapSuite } test("supported schemas") { + assert(supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) + assert(!supportsAggregationBufferSchema( + StructType(StructField("x", DecimalType.SYSTEM_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - assert( !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) } test("empty map") { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index b7bc17f89e82f..a0e1701339ea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -46,7 +46,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(unsafeRow.getLong(1) === 1) assert(unsafeRow.getInt(2) === 2) - // We can copy UnsafeRows as long as they don't reference ObjectPools val unsafeRowCopy = unsafeRow.copy() assert(unsafeRowCopy.getLong(0) === 0) assert(unsafeRowCopy.getLong(1) === 1) @@ -122,8 +121,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { FloatType, DoubleType, StringType, - BinaryType - // DecimalType.Default, + BinaryType, + DecimalType.USER_DEFAULT // ArrayType(IntegerType) ) val converter = UnsafeProjection.create(fieldTypes) @@ -150,7 +149,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) - // assert(createdFromNull.get(10) === null) + assert(createdFromNull.getDecimal(10, 10, 0) === null) // assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those @@ -168,7 +167,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setDouble(7, 700) r.update(8, UTF8String.fromString("hello")) r.update(9, "world".getBytes) - // r.update(10, Decimal(10)) + r.setDecimal(10, Decimal(10), 10) // r.update(11, Array(11)) r } @@ -184,7 +183,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) assert(setToNullAfterCreation.getBinary(9) === rowWithNoNullColumns.getBinary(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- fieldTypes.indices) { @@ -203,7 +203,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { setToNullAfterCreation.setDouble(7, 700) // setToNullAfterCreation.update(8, UTF8String.fromString("hello")) // setToNullAfterCreation.update(9, "world".getBytes) - // setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.setDecimal(10, Decimal(10), 10) // setToNullAfterCreation.update(11, Array(11)) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) @@ -216,7 +216,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) // assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) // assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) - // assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.getDecimal(10, 10, 0) === + rowWithNoNullColumns.getDecimal(10, 10, 0)) // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 454b7b91a63f5..1620fc401ba6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -114,7 +114,7 @@ private[sql] class FixedDecimalColumnBuilder( precision: Int, scale: Int) extends NativeColumnBuilder( - new FixedDecimalColumnStats, + new FixedDecimalColumnStats(precision, scale), FIXED_DECIMAL(precision, scale)) // TODO (lian) Add support for array, struct and map diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 32a84b2676e07..af1a8ecca9b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -234,14 +234,14 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { +private[sql] class FixedDecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDecimal(ordinal) + val value = row.getDecimal(ordinal, precision, scale) if (upper == null || value.compareTo(upper) > 0) upper = value if (lower == null || value.compareTo(lower) < 0) lower = value sizeInBytes += FIXED_DECIMAL.defaultSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 2863f6c230a9d..30f8fe320db3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -392,7 +392,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) } override def getField(row: InternalRow, ordinal: Int): Decimal = { - row.getDecimal(ordinal) + row.getDecimal(ordinal, precision, scale) } override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b85aada9d9d4c..d851eae3fcc71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -202,7 +202,7 @@ case class GeneratedAggregate( val schemaSupportsUnsafe: Boolean = { UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) + UnsafeProjection.canSupport(groupKeySchema) } child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index c808442a4849b..e5bbd0aaed0a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -298,7 
+298,7 @@ private[sql] object SparkSqlSerializer2 { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.getDecimal(i) + val value = row.getDecimal(i, decimal.precision, decimal.scale) val javaBigDecimal = value.toJavaBigDecimal // First, write out the unscaled value. val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 79dd16b7b0c39..ec8da38a3d427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -293,8 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) case BinaryType => writer.addBinary(Binary.fromByteArray(record.getBinary(index))) - case DecimalType.Fixed(precision, _) => - writeDecimal(record.getDecimal(index), precision) + case DecimalType.Fixed(precision, scale) => + writeDecimal(record.getDecimal(index, precision, scale), precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 4499a7207031d..66014ddca0596 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -34,8 +34,7 @@ class ColumnStatsSuite extends SparkFunSuite { testColumnStats(classOf[DoubleColumnStats], DOUBLE, InternalRow(Double.MaxValue, Double.MinValue, 0)) testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[FixedDecimalColumnStats], - FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) + testDecimalColumnStats(InternalRow(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], @@ -52,7 +51,7 @@ class ColumnStatsSuite extends SparkFunSuite { } test(s"$columnStatsName: non-empty") { - import ColumnarTestUtils._ + import org.apache.spark.sql.columnar.ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) @@ -73,4 +72,39 @@ class ColumnStatsSuite extends SparkFunSuite { } } } + + def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats](initialStatistics: InternalRow) { + + val columnStatsName = classOf[FixedDecimalColumnStats].getSimpleName + val columnType = FIXED_DECIMAL(15, 10) + + test(s"$columnStatsName: empty") { + val columnStats = new FixedDecimalColumnStats(15, 10) + columnStats.collectedStatistics.toSeq.zip(initialStatistics.toSeq).foreach { + case (actual, expected) => assert(actual === expected) + } + } + + test(s"$columnStatsName: non-empty") { + import org.apache.spark.sql.columnar.ColumnarTestUtils._ + + val columnStats = new FixedDecimalColumnStats(15, 10) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.take(10).map(_.get(0, columnType.dataType).asInstanceOf[T#InternalType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] + val stats = columnStats.collectedStatistics + + 
assertResult(values.min(ordering), "Wrong lower bound")(stats.genericGet(0))
+ assertResult(values.max(ordering), "Wrong upper bound")(stats.genericGet(1))
+ assertResult(10, "Wrong null count")(stats.genericGet(2))
+ assertResult(20, "Wrong row count")(stats.genericGet(3))
+ assertResult(stats.genericGet(4), "Wrong size in bytes") {
+ rows.map { row =>
+ if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+ }.sum
+ }
+ }
+ }
 }

From 351eda0e2fd47c183c4298469970032097ad07a0 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 30 Jul 2015 17:22:51 -0700
Subject: [PATCH 188/219] [SPARK-6319][SQL] Throw AnalysisException when using
 BinaryType on Join and Aggregate

JIRA: https://issues.apache.org/jira/browse/SPARK-6319

Spark SQL uses plain byte arrays to represent binary values. However, the arrays are compared by reference rather than by value. Thus, we should not use BinaryType on Join and Aggregate in the current implementation.

Author: Liang-Chi Hsieh

Closes #7787 from viirya/agg_no_binary_type and squashes the following commits:

4f76cac [Liang-Chi Hsieh] Throw AnalysisException when using BinaryType on Join and Aggregate.
---
 .../sql/catalyst/analysis/CheckAnalysis.scala | 20 +++++++++++++++++++
 .../spark/sql/DataFrameAggregateSuite.scala | 11 +++++++++-
 .../org/apache/spark/sql/JoinSuite.scala | 9 +++++++++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a373714832962..0ebc3d180a780 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -87,6 +87,18 @@ trait CheckAnalysis {
 s"join condition '${condition.prettyString}' " +
 s"of type ${condition.dataType.simpleString} is not a boolean.")

+ case j @ Join(_, _, _, Some(condition)) =>
+ def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
+ case p: Predicate =>
+ p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
+ case e if e.dataType.isInstanceOf[BinaryType] =>
+ failAnalysis(s"expression ${e.prettyString} in join condition " +
+ s"'${condition.prettyString}' can't be binary type.")
+ case _ => // OK
+ }
+
+ checkValidJoinConditionExprs(condition)
+
 case Aggregate(groupingExprs, aggregateExprs, child) =>
 def checkValidAggregateExpression(expr: Expression): Unit = expr match {
 case _: AggregateExpression => // OK
@@ -100,7 +112,15 @@ trait CheckAnalysis {
 case e => e.children.foreach(checkValidAggregateExpression)
 }

+ def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match {
+ case BinaryType =>
+ failAnalysis(s"grouping expression '${expr.prettyString}' in aggregate can " +
+ s"not be binary type.")
+ case _ => // OK
+ }
+
 aggregateExprs.foreach(checkValidAggregateExpression)
+ groupingExprs.foreach(checkValidGroupingExprs)

 case Sort(orders, _, _) =>
 orders.foreach { order =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index b26d3ab253a1d..228ece8065151 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{BinaryType, DecimalType}

 class DataFrameAggregateSuite extends QueryTest {

@@ -191,4 +191,13 @@ class DataFrameAggregateSuite extends QueryTest {
 Row(null))
 }

+ test("aggregation can't work on binary type") {
+ val df = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
+ intercept[AnalysisException] {
+ df.groupBy("c").agg(count("*"))
+ }
+ intercept[AnalysisException] {
+ df.distinct
+ }
+ }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 666f26bf620e1..27c08f64649ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfterEach
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.types.BinaryType

 class JoinSuite extends QueryTest with BeforeAndAfterEach {

@@ -489,4 +490,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
 Row(3, 2) :: Nil)
 }
+
+ test("Join can't work on binary type") {
+ val left = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("c").select($"c" cast BinaryType)
+ val right = Seq(1, 1, 2, 2).map(i => Tuple1(i.toString)).toDF("d").select($"d" cast BinaryType)
+ intercept[AnalysisException] {
+ left.join(right, (left("c") === right("d")), "full")
+ }
+ }
 }

From 65fa4181c35135080870c1e4c1f904ada3a8cf59 Mon Sep 17 00:00:00 2001
From: Sean Owen
Date: Thu, 30 Jul 2015 17:26:18 -0700
Subject: [PATCH 189/219] [SPARK-9077] [MLLIB] Improve error message for
 decision trees when numExamples < maxCategoriesPerFeature

Improve error message when the number of examples is less than the arity of a high-arity categorical feature

CC jkbradley is this about what you had in mind? I know it's a starter, but was on my list to close out in the short term.

Author: Sean Owen

Closes #7800 from srowen/SPARK-9077 and squashes the following commits:

b8f6cdb [Sean Owen] Improve error message when number of examples is less than arity of high-arity categorical feature
---
 .../spark/mllib/tree/impl/DecisionTreeMetadata.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 380291ac22bd3..9fe264656ede7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -128,9 +128,13 @@ private[spark] object DecisionTreeMetadata extends Logging {
 // based on the number of training examples.
 if (strategy.categoricalFeaturesInfo.nonEmpty) {
 val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
+ val maxCategory =
+ strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
 require(maxCategoriesPerFeature <= maxPossibleBins,
- s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " +
- s"in categorical features (= $maxCategoriesPerFeature)")
+ s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
+ s"number of values in each categorical feature, but categorical feature $maxCategory " +
+ s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " +
+ "features with a large number of values, or add more training examples.")
 }

 val unorderedFeatures = new mutable.HashSet[Int]()

From 3c66ff727d4b47220e1ff363cea215189ed64f36 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Thu, 30 Jul 2015 17:38:48 -0700
Subject: [PATCH 190/219] [SPARK-9489] Remove unnecessary compatibility and
 requirements checks from Exchange

While reviewing yhuai's patch for SPARK-2205 (#7773), I noticed that Exchange's `compatible` check may be incorrectly returning `false` in many cases. As far as I know, this is not actually a problem because the `compatible`, `meetsRequirements`, and `needsAnySort` checks are serving only as short-circuit performance optimizations that are not necessary for correctness.

In order to reduce code complexity, I think that we should remove these checks and unconditionally rewrite the operator's children. This should be safe because we rewrite the tree in a single bottom-up pass.

Author: Josh Rosen

Closes #7807 from JoshRosen/SPARK-9489 and squashes the following commits:

9d76ce9 [Josh Rosen] [SPARK-9489] Remove compatibleWith, meetsRequirements, and needsAnySort checks from Exchange
---
 .../plans/physical/partitioning.scala | 35 ---------
 .../apache/spark/sql/execution/Exchange.scala | 76 +++++--------------
 2 files changed, 17 insertions(+), 94 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2dcfa19fec383..f4d1dbaf28efe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -86,14 +86,6 @@ sealed trait Partitioning {
 */
 def satisfies(required: Distribution): Boolean

- /**
- * Returns true iff all distribution guarantees made by this partitioning can also be made
- * for the `other` specified partitioning.
- * For example, two [[HashPartitioning HashPartitioning]]s are
- * only compatible if the `numPartitions` of them is the same.
- */
- def compatibleWith(other: Partitioning): Boolean
-
 /** Returns the expressions that are used to key the partitioning.
*/ def keyExpressions: Seq[Expression] } @@ -104,11 +96,6 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case UnknownPartitioning(_) => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -117,11 +104,6 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -130,11 +112,6 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def compatibleWith(other: Partitioning): Boolean = other match { - case SinglePartition => true - case _ => false - } - override def keyExpressions: Seq[Expression] = Nil } @@ -159,12 +136,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case h: HashPartitioning if h == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = expressions } @@ -199,11 +170,5 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def compatibleWith(other: Partitioning): Boolean = other match { - case BroadcastPartitioning => true - case r: RangePartitioning if r == this => true - case _ => false - } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 70e5031fb63c0..6bd57f010a990 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -202,41 +202,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => - // True iff every child's outputPartitioning satisfies the corresponding - // required data distribution. - def meetsRequirements: Boolean = - operator.requiredChildDistribution.zip(operator.children).forall { - case (required, child) => - val valid = child.outputPartitioning.satisfies(required) - logDebug( - s"${if (valid) "Valid" else "Invalid"} distribution," + - s"required: $required current: ${child.outputPartitioning}") - valid - } - - // True iff any of the children are incorrectly sorted. - def needsAnySort: Boolean = - operator.requiredChildOrdering.zip(operator.children).exists { - case (required, child) => required.nonEmpty && required != child.outputOrdering - } - - // True iff outputPartitionings of children are compatible with each other. - // It is possible that every child satisfies its required data distribution - // but two children have incompatible outputPartitionings. For example, - // A dataset is range partitioned by "a.asc" (RangePartitioning) and another - // dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two - // datasets are both clustered by "a", but these two outputPartitionings are not - // compatible. - // TODO: ASSUMES TRANSITIVITY? 
- def compatible: Boolean = - operator.children - .map(_.outputPartitioning) - .sliding(2) - .forall { - case Seq(a) => true - case Seq(a, b) => a.compatibleWith(b) - } - // Adds Exchange or Sort operators as required def addOperatorsIfNecessary( partitioning: Partitioning, @@ -269,33 +234,26 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ addSortIfNecessary(addShuffleIfNecessary(child)) } - if (meetsRequirements && compatible && !needsAnySort) { - operator - } else { - // At least one child does not satisfies its required data distribution or - // at least one child's outputPartitioning is not compatible with another child's - // outputPartitioning. In this case, we need to add Exchange operators. - val requirements = - (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) + val requirements = + (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children) - val fixedChildren = requirements.zipped.map { - case (AllTuples, rowOrdering, child) => - addOperatorsIfNecessary(SinglePartition, rowOrdering, child) - case (ClusteredDistribution(clustering), rowOrdering, child) => - addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) - case (OrderedDistribution(ordering), rowOrdering, child) => - addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) + val fixedChildren = requirements.zipped.map { + case (AllTuples, rowOrdering, child) => + addOperatorsIfNecessary(SinglePartition, rowOrdering, child) + case (ClusteredDistribution(clustering), rowOrdering, child) => + addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child) + case (OrderedDistribution(ordering), rowOrdering, child) => + addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child) - case (UnspecifiedDistribution, Seq(), child) => - child - case (UnspecifiedDistribution, rowOrdering, child) => - sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) + case (UnspecifiedDistribution, Seq(), child) => + child + case (UnspecifiedDistribution, rowOrdering, child) => + sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child) - case (dist, ordering, _) => - sys.error(s"Don't know how to ensure $dist with ordering $ordering") - } - - operator.withNewChildren(fixedChildren) + case (dist, ordering, _) => + sys.error(s"Don't know how to ensure $dist with ordering $ordering") } + + operator.withNewChildren(fixedChildren) } } From 9307f5653d19a6a2fda355a675ca9ea97e35611b Mon Sep 17 00:00:00 2001 From: cody koeninger Date: Thu, 30 Jul 2015 17:44:20 -0700 Subject: [PATCH 191/219] [SPARK-9472] [STREAMING] consistent hadoop configuration, streaming only Author: cody koeninger Closes #7772 from koeninger/streaming-hadoop-config and squashes the following commits: 5267284 [cody koeninger] [SPARK-4229][Streaming] consistent hadoop configuration, streaming only --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 3 ++- .../org/apache/spark/streaming/StreamingContext.scala | 7 ++++--- .../apache/spark/streaming/api/java/JavaPairDStream.scala | 2 +- .../spark/streaming/api/java/JavaStreamingContext.scala | 3 ++- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 65d4e933bf8e9..2780d5b6adbcf 100644 --- 
a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{MetadataCleaner, Utils} import org.apache.spark.streaming.scheduler.JobGenerator @@ -100,7 +101,7 @@ object Checkpoint extends Logging { } val path = new Path(checkpointDir) - val fs = fsOption.getOrElse(path.getFileSystem(new Configuration())) + val fs = fsOption.getOrElse(path.getFileSystem(SparkHadoopUtil.get.conf)) if (fs.exists(path)) { val statuses = fs.listStatus(path) if (statuses != null) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 92438f1b1fbf7..177e710ace54b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.input.FixedLengthBinaryInputFormat import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.serializer.SerializationDebugger @@ -110,7 +111,7 @@ class StreamingContext private[streaming] ( * Recreate a StreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(path, new Configuration) + def this(path: String) = this(path, SparkHadoopUtil.get.conf) /** * Recreate a StreamingContext from a checkpoint file using an existing SparkContext. 
@@ -803,7 +804,7 @@ object StreamingContext extends Logging { def getActiveOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { ACTIVATION_LOCK.synchronized { @@ -828,7 +829,7 @@ object StreamingContext extends Logging { def getOrCreate( checkpointPath: String, creatingFunc: () => StreamingContext, - hadoopConf: Configuration = new Configuration(), + hadoopConf: Configuration = SparkHadoopUtil.get.conf, createOnError: Boolean = false ): StreamingContext = { val checkpointOption = CheckpointReader.read( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 959ac9c177f81..26383e420101e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -788,7 +788,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( keyClass: Class[_], valueClass: Class[_], outputFormatClass: Class[F], - conf: Configuration = new Configuration) { + conf: Configuration = dstream.context.sparkContext.hadoopConfiguration) { dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 40deb6d7ea79a..35cc3ce5cf468 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -33,6 +33,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} import org.apache.spark.api.java.function.{Function0 => JFunction0} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ @@ -136,7 +137,7 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Recreate a JavaStreamingContext from a checkpoint file. * @param path Path to the directory that was specified as the checkpoint directory */ - def this(path: String) = this(new StreamingContext(path, new Configuration)) + def this(path: String) = this(new StreamingContext(path, SparkHadoopUtil.get.conf)) /** * Re-creates a JavaStreamingContext from a checkpoint file. From 83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Thu, 30 Jul 2015 19:22:38 -0700 Subject: [PATCH 192/219] [SPARK-8176] [SPARK-8197] [SQL] function to_date/ trunc This PR is based on #6988 , thanks to adrian-wang . This brings two SQL functions: to_date() and trunc(). 
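For illustration, a minimal sketch of how the two functions can be called from the DataFrame API (the spark-shell `sqlContext`, the sample DataFrame, and the column name "t" are assumptions for the example, not part of the patch):

    import org.apache.spark.sql.functions.{col, to_date, trunc}

    // Hypothetical one-column DataFrame of timestamp-formatted strings.
    val df = sqlContext.createDataFrame(Seq(Tuple1("2015-07-22 10:00:00"))).toDF("t")

    // to_date() keeps only the date part: 2015-07-22
    df.select(to_date(col("t"))).show()

    // trunc() resets everything below the given unit ("year"/"yyyy"/"yy" or
    // "month"/"mon"/"mm"): 2015-07-01
    df.select(trunc(to_date(col("t")), "mm")).show()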
Closes #6988 Author: Daoyuan Wang Author: Davies Liu Closes #7805 from davies/to_date and squashes the following commits: 2c7beba [Davies Liu] Merge branch 'master' of github.com:apache/spark into to_date 310dd55 [Daoyuan Wang] remove dup test in rebase 980b092 [Daoyuan Wang] resolve rebase conflict a476c5a [Daoyuan Wang] address comments from davies d44ea5f [Daoyuan Wang] function to_date, trunc --- python/pyspark/sql/functions.py | 30 +++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/datetimeFunctions.scala | 88 ++++++++++++++++++- .../sql/catalyst/util/DateTimeUtils.scala | 34 +++++++ .../expressions/DateExpressionsSuite.scala | 29 +++++- .../expressions/NonFoldableLiteral.scala | 4 + .../org/apache/spark/sql/functions.scala | 16 ++++ .../apache/spark/sql/DateFunctionsSuite.scala | 44 ++++++++++ 8 files changed, 245 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a7295e25f0aa5..8024a8de07c98 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -888,6 +888,36 @@ def months_between(date1, date2): return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2))) +@since(1.5) +def to_date(col): + """ + Converts the column of StringType or TimestampType into DateType. + + >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00',)], ['t']) + >>> df.select(to_date(df.t).alias('date')).collect() + [Row(date=datetime.date(1997, 2, 28))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.to_date(_to_java_column(col))) + + +@since(1.5) +def trunc(date, format): + """ + Returns date truncated to the unit specified by the format. + + :param format: 'year', 'YYYY', 'yy' or 'month', 'mon', 'mm' + + >>> df = sqlContext.createDataFrame([('1997-02-28',)], ['d']) + >>> df.select(trunc(df.d, 'year').alias('year')).collect() + [Row(year=datetime.date(1997, 1, 1))] + >>> df.select(trunc(df.d, 'mon').alias('month')).collect() + [Row(month=datetime.date(1997, 2, 1))] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.trunc(_to_java_column(date), format)) + + @since(1.5) def size(col): """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6c7c481fab8db..1bf7204a2515c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -223,6 +223,8 @@ object FunctionRegistry { expression[NextDay]("next_day"), expression[Quarter]("quarter"), expression[Second]("second"), + expression[ToDate]("to_date"), + expression[TruncDate]("trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 9795673ee0664..6e7613340c032 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -507,7 +507,6 @@ case class FromUnixTime(sec: Expression, format: Expression) }) } } - } /** @@ -696,3 +695,90 @@ case class 
MonthsBetween(date1: Expression, date2: Expression) }) } } + +/** + * Returns the date part of a timestamp or string. + */ +case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + // Implicit casting of spark will accept string in both date and timestamp format, as + // well as TimestampType. + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = DateType + + override def eval(input: InternalRow): Any = child.eval(input) + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, d => d) + } +} + +/* + * Returns date truncated to the unit specified by the format. + */ +case class TruncDate(date: Expression, format: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + override def left: Expression = date + override def right: Expression = format + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) + override def dataType: DataType = DateType + override def prettyName: String = "trunc" + + lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + + override def eval(input: InternalRow): Any = { + val minItem = if (format.foldable) { + minItemConst + } else { + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + } + if (minItem == -1) { + // unknown format + null + } else { + val d = date.eval(input) + if (d == null) { + null + } else { + DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + + if (format.foldable) { + if (minItemConst == -1) { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } else { + val d = date.gen(ctx) + s""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + } + """ + } + } else { + nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { + val form = ctx.freshName("form") + s""" + int $form = $dtu.parseTruncLevel($fmt); + if ($form == -1) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $dtu.truncDate($dateVal, $form); + } + """ + }) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 53abdf6618eac..5a7c25b8d508d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -779,4 +779,38 @@ object DateTimeUtils { } date + (lastDayOfMonthInYear - dayInYear) } + + private val TRUNC_TO_YEAR = 1 + private val TRUNC_TO_MONTH = 2 + private val TRUNC_INVALID = -1 + + /** + * Returns the trunc date from original date and trunc level. + * Trunc level should be generated using `parseTruncLevel()`, should only be 1 or 2. 
+ */
+ def truncDate(d: Int, level: Int): Int = {
+ if (level == TRUNC_TO_YEAR) {
+ d - DateTimeUtils.getDayInYear(d) + 1
+ } else if (level == TRUNC_TO_MONTH) {
+ d - DateTimeUtils.getDayOfMonth(d) + 1
+ } else {
+ throw new Exception(s"Invalid trunc level: $level")
+ }
+ }
+
+ /**
+ * Returns the truncate level: one of TRUNC_TO_YEAR, TRUNC_TO_MONTH, or TRUNC_INVALID;
+ * TRUNC_INVALID means an unsupported truncate level.
+ */
+ def parseTruncLevel(format: UTF8String): Int = {
+ if (format == null) {
+ TRUNC_INVALID
+ } else {
+ format.toString.toUpperCase match {
+ case "YEAR" | "YYYY" | "YY" => TRUNC_TO_YEAR
+ case "MON" | "MONTH" | "MM" => TRUNC_TO_MONTH
+ case _ => TRUNC_INVALID
+ }
+ }
+ }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 887e43621a941..6c15c05da3094 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -351,6 +351,34 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 NextDay(Literal(Date.valueOf("2015-07-23")), Literal.create(null, StringType)), null)
 }

+ test("function to_date") {
+ checkEvaluation(
+ ToDate(Literal(Date.valueOf("2015-07-22"))),
+ DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-22")))
+ checkEvaluation(ToDate(Literal.create(null, DateType)), null)
+ }
+
+ test("function trunc") {
+ def testTrunc(input: Date, fmt: String, expected: Date): Unit = {
+ checkEvaluation(TruncDate(Literal.create(input, DateType), Literal.create(fmt, StringType)),
+ expected)
+ checkEvaluation(
+ TruncDate(Literal.create(input, DateType), NonFoldableLiteral.create(fmt, StringType)),
+ expected)
+ }
+ val date = Date.valueOf("2015-07-22")
+ Seq("yyyy", "YYYY", "year", "YEAR", "yy", "YY").foreach { fmt =>
+ testTrunc(date, fmt, Date.valueOf("2015-01-01"))
+ }
+ Seq("month", "MONTH", "mon", "MON", "mm", "MM").foreach { fmt =>
+ testTrunc(date, fmt, Date.valueOf("2015-07-01"))
+ }
+ testTrunc(date, "DD", null)
+ testTrunc(date, null, null)
+ testTrunc(null, "MON", null)
+ testTrunc(null, null, null)
+ }
+
 test("from_unixtime") {
 val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
 val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS"
@@ -405,5 +433,4 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 checkEvaluation(
 UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format")), null)
 }
-
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
index 0559fb80e7fce..31ecf4a9e810a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala
@@ -47,4 +47,8 @@ object NonFoldableLiteral {
 val lit = Literal(value)
 NonFoldableLiteral(lit.value, lit.dataType)
 }
+ def create(value: Any, dataType: DataType): NonFoldableLiteral = {
+ val lit = Literal.create(value, dataType)
+ NonFoldableLiteral(lit.value, lit.dataType)
+ }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 168894d66117d..46dc4605a5ccb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2181,6 +2181,22 @@ object functions { */ def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + /* + * Converts the column into DateType. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def to_date(e: Column): Column = ToDate(e.expr) + + /** + * Returns date truncated to the unit specified by the format. + * + * @group datetime_funcs + * @since 1.5.0 + */ + def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index b7267c413165a..8c596fad74ee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -345,6 +345,50 @@ class DateFunctionsSuite extends QueryTest { Seq(Row(Date.valueOf("2015-07-30")), Row(Date.valueOf("2015-07-30")))) } + test("function to_date") { + val d1 = Date.valueOf("2015-07-22") + val d2 = Date.valueOf("2015-07-01") + val t1 = Timestamp.valueOf("2015-07-22 10:00:00") + val t2 = Timestamp.valueOf("2014-12-31 23:59:59") + val s1 = "2015-07-22 10:00:00" + val s2 = "2014-12-31" + val df = Seq((d1, t1, s1), (d2, t2, s2)).toDF("d", "t", "s") + + checkAnswer( + df.select(to_date(col("t"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.select(to_date(col("d"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.select(to_date(col("s"))), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + + checkAnswer( + df.selectExpr("to_date(t)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + checkAnswer( + df.selectExpr("to_date(d)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2015-07-01")))) + checkAnswer( + df.selectExpr("to_date(s)"), + Seq(Row(Date.valueOf("2015-07-22")), Row(Date.valueOf("2014-12-31")))) + } + + test("function trunc") { + val df = Seq( + (1, Timestamp.valueOf("2015-07-22 10:00:00")), + (2, Timestamp.valueOf("2014-12-31 00:00:00"))).toDF("i", "t") + + checkAnswer( + df.select(trunc(col("t"), "YY")), + Seq(Row(Date.valueOf("2015-01-01")), Row(Date.valueOf("2014-01-01")))) + + checkAnswer( + df.selectExpr("trunc(t, 'Month')"), + Seq(Row(Date.valueOf("2015-07-01")), Row(Date.valueOf("2014-12-01")))) + } + test("from_unixtime") { val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" From 4e5919bfb47a58bcbda90ae01c1bed2128ded983 Mon Sep 17 00:00:00 2001 From: Ram Sriharsha Date: Thu, 30 Jul 2015 23:02:11 -0700 Subject: [PATCH 193/219] [SPARK-7690] [ML] Multiclass classification Evaluator Multiclass Classification Evaluator for ML Pipelines. F1 score, precision, recall, weighted precision and weighted recall are supported as available metrics. 
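A minimal usage sketch (the `predictions` DataFrame with DoubleType "prediction" and "label" columns is assumed to come from a fitted classifier; it is not defined by this patch):

    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

    // Evaluate predictions against labels; metricName defaults to "f1".
    val evaluator = new MulticlassClassificationEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("label")
      .setMetricName("weightedPrecision") // or f1 | precision | recall | weightedRecall

    val metric: Double = evaluator.evaluate(predictions)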
Author: Ram Sriharsha

Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits:

9bf4ec7 [Ram Sriharsha] fix indentation
3f09a85 [Ram Sriharsha] cleanup doc
16115ae [Ram Sriharsha] code review fixes
032d2a3 [Ram Sriharsha] fix test
eec9865 [Ram Sriharsha] Fix Python Indentation
1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690
68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690
54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline
---
 .../MulticlassClassificationEvaluator.scala | 85 +++++++++++++++++++
 ...lticlassClassificationEvaluatorSuite.scala | 28 ++++++
 python/pyspark/ml/evaluation.py | 66 ++++++++++++++
 3 files changed, 179 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
new file mode 100644
index 0000000000000..44f779c1908d7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.types.DoubleType
+
+/**
+ * :: Experimental ::
+ * Evaluator for multiclass classification, which expects two input columns: prediction and label.
+ */ +@Experimental +class MulticlassClassificationEvaluator (override val uid: String) + extends Evaluator with HasPredictionCol with HasLabelCol { + + def this() = this(Identifiable.randomUID("mcEval")) + + /** + * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, + * `"weightedPrecision"`, `"weightedRecall"`) + * @group param + */ + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("f1", "precision", + "recall", "weightedPrecision", "weightedRecall")) + new Param(this, "metricName", "metric name in evaluation " + + "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + } + + /** @group getParam */ + def getMetricName: String = $(metricName) + + /** @group setParam */ + def setMetricName(value: String): this.type = set(metricName, value) + + /** @group setParam */ + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + def setLabelCol(value: String): this.type = set(labelCol, value) + + setDefault(metricName -> "f1") + + override def evaluate(dataset: DataFrame): Double = { + val schema = dataset.schema + SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) + SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + + val predictionAndLabels = dataset.select($(predictionCol), $(labelCol)) + .map { case Row(prediction: Double, label: Double) => + (prediction, label) + } + val metrics = new MulticlassMetrics(predictionAndLabels) + val metric = $(metricName) match { + case "f1" => metrics.weightedFMeasure + case "precision" => metrics.precision + case "recall" => metrics.recall + case "weightedPrecision" => metrics.weightedPrecision + case "weightedRecall" => metrics.weightedRecall + } + metric + } + + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala new file mode 100644 index 0000000000000..6d8412b0b3701 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite + +class MulticlassClassificationEvaluatorSuite extends SparkFunSuite { + + test("params") { + ParamsSuite.checkParams(new MulticlassClassificationEvaluator) + } +} diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 595593a7f2cde..06e809352225b 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -214,6 +214,72 @@ def setParams(self, predictionCol="prediction", labelCol="label", kwargs = self.setParams._input_kwargs return self._set(**kwargs) + +@inherit_doc +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Multiclass Classification, which expects two input + columns: prediction and label. + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) + ... + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") + >>> evaluator.evaluate(dataset) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + 0.66... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation " + "(f1|precision|recall|weightedPrecision|weightedRecall)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + """ + super(MulticlassClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) + # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) + self.metricName = Param(self, "metricName", + "metric name in evaluation" + " (f1|precision|recall|weightedPrecision|weightedRecall)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="f1") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + Sets params for multiclass classification evaluator. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 69b62f76fced18efa35a107c9be4bc22eba72878 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 30 Jul 2015 23:03:48 -0700 Subject: [PATCH 194/219] [SPARK-9214] [ML] [PySpark] support ml.NaiveBayes for Python support ml.NaiveBayes for Python Author: Yanbo Liang Closes #7568 from yanboliang/spark-9214 and squashes the following commits: 5ee3fd6 [Yanbo Liang] fix typos 3ecd046 [Yanbo Liang] fix typos f9c94d1 [Yanbo Liang] change lambda_ to smoothing and fix other issues 180452a [Yanbo Liang] fix typos 7dda1f4 [Yanbo Liang] support ml.NaiveBayes for Python --- .../spark/ml/classification/NaiveBayes.scala | 10 +- .../classification/JavaNaiveBayesSuite.java | 4 +- .../ml/classification/NaiveBayesSuite.scala | 6 +- python/pyspark/ml/classification.py | 116 +++++++++++++++++- 4 files changed, 125 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 1f547e4a98af7..5be35fe209291 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -38,11 +38,11 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * (default = 1.0). * @group param */ - final val lambda: DoubleParam = new DoubleParam(this, "lambda", "The smoothing parameter.", + final val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "The smoothing parameter.", ParamValidators.gtEq(0)) /** @group getParam */ - final def getLambda: Double = $(lambda) + final def getSmoothing: Double = $(smoothing) /** * The model type which is a string (case-sensitive). @@ -79,8 +79,8 @@ class NaiveBayes(override val uid: String) * Default is 1.0. * @group setParam */ - def setLambda(value: Double): this.type = set(lambda, value) - setDefault(lambda -> 1.0) + def setSmoothing(value: Double): this.type = set(smoothing, value) + setDefault(smoothing -> 1.0) /** * Set the model type using a string (case-sensitive). 
@@ -92,7 +92,7 @@ class NaiveBayes(override val uid: String) override protected def train(dataset: DataFrame): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) - val oldModel = OldNaiveBayes.train(oldDataset, $(lambda), $(modelType)) + val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 09a9fba0c19cf..a700c9cddb206 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -68,7 +68,7 @@ public void naiveBayesDefaultParams() { assert(nb.getLabelCol() == "label"); assert(nb.getFeaturesCol() == "features"); assert(nb.getPredictionCol() == "prediction"); - assert(nb.getLambda() == 1.0); + assert(nb.getSmoothing() == 1.0); assert(nb.getModelType() == "multinomial"); } @@ -89,7 +89,7 @@ public void testNaiveBayes() { }); DataFrame dataset = jsql.createDataFrame(jrdd, schema); - NaiveBayes nb = new NaiveBayes().setLambda(0.5).setModelType("multinomial"); + NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 76381a2741296..264bde3703c5f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -58,7 +58,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(nb.getLabelCol === "label") assert(nb.getFeaturesCol === "features") assert(nb.getPredictionCol === "prediction") - assert(nb.getLambda === 1.0) + assert(nb.getSmoothing === 1.0) assert(nb.getModelType === "multinomial") } @@ -75,7 +75,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 42, "multinomial")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("multinomial") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) @@ -101,7 +101,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val testDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 45, "bernoulli")) - val nb = new NaiveBayes().setLambda(1.0).setModelType("bernoulli") + val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli") val model = nb.fit(testDataset) validateModelFit(pi, theta, model) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5a82bc286d1e8..93ffcd40949b3 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -25,7 +25,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', - 'RandomForestClassifier', 'RandomForestClassificationModel'] + 'RandomForestClassifier', 'RandomForestClassificationModel', 
'NaiveBayes', + 'NaiveBayesModel'] @inherit_doc @@ -576,6 +577,119 @@ class GBTClassificationModel(TreeEnsembleModels): """ +@inherit_doc +class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol): + """ + Naive Bayes Classifiers. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... Row(label=0.0, features=Vectors.dense([0.0, 0.0])), + ... Row(label=0.0, features=Vectors.dense([0.0, 1.0])), + ... Row(label=1.0, features=Vectors.dense([1.0, 0.0]))]) + >>> nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + >>> model = nb.fit(df) + >>> model.pi + DenseVector([-0.51..., -0.91...]) + >>> model.theta + DenseMatrix(2, 2, [-1.09..., -0.40..., -0.40..., -1.09...], 1) + >>> test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF() + >>> model.transform(test0).head().prediction + 1.0 + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction + 1.0 + """ + + # a placeholder to make it appear in the generated doc + smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + modelType = Param(Params._dummy(), "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) and bernoulli.") + + @keyword_only + def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial"): + """ + __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial") + """ + super(NaiveBayes, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.classification.NaiveBayes", self.uid) + #: param for the smoothing parameter. + self.smoothing = Param(self, "smoothing", "The smoothing parameter, should be >= 0, " + + "default is 1.0") + #: param for the model type. + self.modelType = Param(self, "modelType", "The model type which is a string " + + "(case-sensitive). Supported options: multinomial (default) " + + "and bernoulli.") + self._setDefault(smoothing=1.0, modelType="multinomial") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial"): + """ + setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", + smoothing=1.0, modelType="multinomial") + Sets params for Naive Bayes. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return NaiveBayesModel(java_model) + + def setSmoothing(self, value): + """ + Sets the value of :py:attr:`smoothing`. + """ + self._paramMap[self.smoothing] = value + return self + + def getSmoothing(self): + """ + Gets the value of smoothing or its default value. + """ + return self.getOrDefault(self.smoothing) + + def setModelType(self, value): + """ + Sets the value of :py:attr:`modelType`. + """ + self._paramMap[self.modelType] = value + return self + + def getModelType(self): + """ + Gets the value of modelType or its default value. + """ + return self.getOrDefault(self.modelType) + + +class NaiveBayesModel(JavaModel): + """ + Model fitted by NaiveBayes. + """ + + @property + def pi(self): + """ + log of class priors. 
+ """ + return self._call_java("pi") + + @property + def theta(self): + """ + log of class conditional probabilities. + """ + return self._call_java("theta") + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext From 0244170b66476abc4a39ed609a852f1a6fa455e7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 30 Jul 2015 23:05:58 -0700 Subject: [PATCH 195/219] [SPARK-9152][SQL] Implement code generation for Like and RLike JIRA: https://issues.apache.org/jira/browse/SPARK-9152 This PR implements code generation for `Like` and `RLike`. Author: Liang-Chi Hsieh Closes #7561 from viirya/like_rlike_codegen and squashes the following commits: fe5641b [Liang-Chi Hsieh] Add test for NonFoldableLiteral. ccd1b43 [Liang-Chi Hsieh] For comments. 0086723 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen 50df9a8 [Liang-Chi Hsieh] Use nullSafeCodeGen. 8092a68 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen 696d451 [Liang-Chi Hsieh] Check expression foldable. 48e5536 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen aea58e0 [Liang-Chi Hsieh] For comments. 46d946f [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into like_rlike_codegen a0fb76e [Liang-Chi Hsieh] For comments. 6cffe3c [Liang-Chi Hsieh] For comments. 69f0fb6 [Liang-Chi Hsieh] Add code generation for Like and RLike. --- .../expressions/stringOperations.scala | 105 ++++++++++++++---- .../spark/sql/catalyst/util/StringUtils.scala | 47 ++++++++ .../expressions/StringExpressionsSuite.scala | 16 +++ .../sql/catalyst/util/StringUtilsSuite.scala | 34 ++++++ 4 files changed, 180 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 79c0ca56a8e79..99a62343f138d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -21,8 +21,11 @@ import java.text.DecimalFormat import java.util.Locale import java.util.regex.{MatchResult, Pattern} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -160,32 +163,51 @@ trait StringRegexExpression extends ImplicitCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression with CodegenFallback { - // replace the _ with .{1} exactly match 1 time of any character - // replace the % with .*, match 0 or more times with any character - override def escape(v: String): String = - if (!v.isEmpty) { - "(?s)" + (' ' +: v.init).zip(v).flatMap { - case (prev, '\\') => "" - case ('\\', c) => - c match { - case '_' => "_" - case '%' => "%" - case _ => Pattern.quote("\\" + c) - } - case (prev, c) => - c match { - case '_' => "." 
- case '%' => ".*" - case _ => Pattern.quote(Character.toString(c)) - } - }.mkString - } else { - v - } + override def escape(v: String): String = StringUtils.escapeLikeRegex(v) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = s"$left LIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. + val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).matches(); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile($escapeFunc(rightStr)); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).matches(); + """ + }) + } + } } @@ -195,6 +217,45 @@ case class RLike(left: Expression, right: Expression) override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val patternClass = classOf[Pattern].getName + val pattern = ctx.freshName("pattern") + + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { + val regexStr = + StringEscapeUtils.escapeJava(rVal.asInstanceOf[UTF8String].toString()) + ctx.addMutableState(patternClass, pattern, + s"""$pattern = ${patternClass}.compile("$regexStr");""") + + // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
+ val eval = left.gen(ctx) + s""" + ${eval.code} + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.primitive} = $pattern.matcher(${eval.primitive}.toString()).find(0); + } + """ + } else { + s""" + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + """ + } + } else { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + s""" + String rightStr = ${eval2}.toString(); + ${patternClass} $pattern = ${patternClass}.compile(rightStr); + ${ev.primitive} = $pattern.matcher(${eval1}.toString()).find(0); + """ + }) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala new file mode 100644 index 0000000000000..9ddfb3a0d3759 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import java.util.regex.Pattern + +object StringUtils { + + // replace the _ with .{1} exactly match 1 time of any character + // replace the % with .*, match 0 or more times with any character + def escapeLikeRegex(v: String): String = { + if (!v.isEmpty) { + "(?s)" + (' ' +: v.init).zip(v).flatMap { + case (prev, '\\') => "" + case ('\\', c) => + c match { + case '_' => "_" + case '%' => "%" + case _ => Pattern.quote("\\" + c) + } + case (prev, c) => + c match { + case '_' => "." 
+ case '%' => ".*" + case _ => Pattern.quote(Character.toString(c)) + } + }.mkString + } else { + v + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 07b952531ec2e..3ecd0d374c46b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -191,6 +191,15 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType).like("a"), null) checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null) checkEvaluation(Literal.create(null, StringType).like(Literal.create(null, StringType)), null) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create("a", StringType)), true) + checkEvaluation( + Literal.create("a", StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create("a", StringType)), null) + checkEvaluation( + Literal.create(null, StringType).like(NonFoldableLiteral.create(null, StringType)), null) + checkEvaluation("abdef" like "abdef", true) checkEvaluation("a_%b" like "a\\__b", true) checkEvaluation("addb" like "a_%b", true) @@ -232,6 +241,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal.create(null, StringType) rlike "abdef", null) checkEvaluation("abdef" rlike Literal.create(null, StringType), null) checkEvaluation(Literal.create(null, StringType) rlike Literal.create(null, StringType), null) + checkEvaluation("abdef" rlike NonFoldableLiteral.create("abdef", StringType), true) + checkEvaluation("abdef" rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create("abdef", StringType), null) + checkEvaluation( + Literal.create(null, StringType) rlike NonFoldableLiteral.create(null, StringType), null) + checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala new file mode 100644 index 0000000000000..d6f273f9e568a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.StringUtils._
+
+class StringUtilsSuite extends SparkFunSuite {
+
+  test("escapeLikeRegex") {
+    assert(escapeLikeRegex("abdef") === "(?s)\\Qa\\E\\Qb\\E\\Qd\\E\\Qe\\E\\Qf\\E")
+    assert(escapeLikeRegex("a\\__b") === "(?s)\\Qa\\E_.\\Qb\\E")
+    assert(escapeLikeRegex("a_%b") === "(?s)\\Qa\\E..*\\Qb\\E")
+    assert(escapeLikeRegex("a%\\%b") === "(?s)\\Qa\\E.*%\\Qb\\E")
+    assert(escapeLikeRegex("a%") === "(?s)\\Qa\\E.*")
+    assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
+    assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
+  }
+}

From a3a85d73da053c8e2830759fbc68b734081fa4f3 Mon Sep 17 00:00:00 2001
From: WangTaoTheTonic
Date: Thu, 30 Jul 2015 23:50:06 -0700
Subject: [PATCH 196/219] [SPARK-9496][SQL] do not print the password in config

https://issues.apache.org/jira/browse/SPARK-9496

We had better not print the password in the log.

Author: WangTaoTheTonic

Closes #7815 from WangTaoTheTonic/master and squashes the following commits:

c7a5145 [WangTaoTheTonic] do not print the password in config
---
 .../org/apache/spark/sql/hive/client/ClientWrapper.scala | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 8adda54754230..6e0912da5862d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -91,7 +91,11 @@ private[hive] class ClientWrapper(
     // this action explicit.
     initialConf.setClassLoader(initClassLoader)
     config.foreach { case (k, v) =>
-      logDebug(s"Hive Config: $k=$v")
+      if (k.toLowerCase.contains("password")) {
+        logDebug(s"Hive Config: $k=xxx")
+      } else {
+        logDebug(s"Hive Config: $k=$v")
+      }
       initialConf.set(k, v)
     }
     val newState = new SessionState(initialConf)

From 6bba7509a932aa4d39266df2d15b1370b7aabbec Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 31 Jul 2015 08:28:05 -0700
Subject: [PATCH 197/219] [SPARK-9500] add TernaryExpression to simplify ternary expressions

There is a lot of duplicated code in ternary expressions; create a TernaryExpression for them to reduce the duplication.

cc chenghao-intel

Author: Davies Liu

Closes #7816 from davies/ternary and squashes the following commits:

ed2bf76 [Davies Liu] add TernaryExpression
---
 .../sql/catalyst/expressions/Expression.scala | 85 +++++
 .../expressions/codegen/CodeGenerator.scala | 2 +-
 .../spark/sql/catalyst/expressions/math.scala | 66 +---
 .../expressions/stringOperations.scala | 356 +++++-------------
 4 files changed, 183 insertions(+), 326 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8fc182607ce68..2842b3ec5a0c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -432,3 +432,88 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
 private[sql] object BinaryOperator {
   def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
+
+/**
+ * An expression with three inputs and one output.
The output is by default evaluated to null
+ * if any input is evaluated to null.
+ */
+abstract class TernaryExpression extends Expression {
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  /**
+   * Default behavior of evaluation according to the default nullability of TernaryExpression.
+   * If a subclass of TernaryExpression overrides nullable, it should probably also override this.
+   */
+  override def eval(input: InternalRow): Any = {
+    val exprs = children
+    val value1 = exprs(0).eval(input)
+    if (value1 != null) {
+      val value2 = exprs(1).eval(input)
+      if (value2 != null) {
+        val value3 = exprs(2).eval(input)
+        if (value3 != null) {
+          return nullSafeEval(value1, value2, value3)
+        }
+      }
+    }
+    null
+  }
+
+  /**
+   * Called by the default [[eval]] implementation. If a subclass of TernaryExpression keeps the
+   * default nullability, it can override this method to save null-check code. If we need full
+   * control of the evaluation process, we should override [[eval]].
+   */
+  protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any =
+    sys.error(s"TernaryExpressions must override either eval or nullSafeEval")
+
+  /**
+   * Short hand for generating ternary evaluation code.
+   * If any of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   *
+   * @param f accepts three variable names and returns Java code to compute the output.
+   */
+  protected def defineCodeGen(
+      ctx: CodeGenContext,
+      ev: GeneratedExpressionCode,
+      f: (String, String, String) => String): String = {
+    nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => {
+      s"${ev.primitive} = ${f(eval1, eval2, eval3)};"
+    })
+  }
+
+  /**
+   * Short hand for generating ternary evaluation code.
+   * If any of the sub-expressions is null, the result of this computation
+   * is assumed to be null.
+   *
+   * @param f function that accepts the 3 non-null evaluation result names of children
+   * and returns Java code to compute the output.
+ */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val evals = children.map(_.gen(ctx)) + val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive) + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } + } + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60e2863f7bbb0..e50ec27fc2eb6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -305,7 +305,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - val msg = "failed to compile:\n " + CodeFormatter.format(code) + val msg = s"failed to compile: $e\n" + CodeFormatter.format(code) logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e6d807f6d897b..15ceb9193a8c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -165,69 +165,29 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes { - - override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable - - override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) - override def dataType: DataType = StringType - /** Returns the result of evaluating this expression on a given input Row */ - override def eval(input: InternalRow): Any = { - val num = numExpr.eval(input) - if (num != null) { - val fromBase = fromBaseExpr.eval(input) - if (fromBase != null) { - val toBase = toBaseExpr.eval(input) - if (toBase != null) { - NumberConverter.convert( - num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], - toBase.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numGen = numExpr.gen(ctx) - val from = fromBaseExpr.gen(ctx) - val to = toBaseExpr.gen(ctx) - val numconv = NumberConverter.getClass.getName.stripSuffix("$") - s""" - 
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${numGen.code} - boolean ${ev.isNull} = ${numGen.isNull}; - if (!${ev.isNull}) { - ${from.code} - if (!${from.isNull}) { - ${to.code} - if (!${to.isNull}) { - ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), - ${from.primitive}, ${to.primitive}); - if (${ev.primitive} == null) { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } + nullSafeCodeGen(ctx, ev, (num, from, to) => + s""" + ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to); + if (${ev.primitive} == null) { + ${ev.isNull} = true; } - """ + """ + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 99a62343f138d..684eac12bd6f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -426,15 +426,13 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -467,60 +465,18 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. 
*/ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.lpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } override def prettyName: String = "lpad" @@ -530,60 +486,18 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. 
*/ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.rpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } override def prettyName: String = "rpad" @@ -745,68 +659,24 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. 
*/ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) } - override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - override def eval(input: InternalRow): Any = { - val stringEval = str.eval(input) - if (stringEval != null) { - val posEval = pos.eval(input) - if (posEval != null) { - val lenEval = len.eval(input) - if (lenEval != null) { - stringEval.asInstanceOf[UTF8String] - .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { + string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val strGen = str.gen(ctx) - val posGen = pos.gen(ctx) - val lenGen = len.gen(ctx) - - val start = ctx.freshName("start") - val end = ctx.freshName("end") - - s""" - ${strGen.code} - boolean ${ev.isNull} = ${strGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${posGen.code} - if (!${posGen.isNull}) { - ${lenGen.code} - if (!${lenGen.isNull}) { - ${ev.primitive} = ${strGen.primitive} - .substringSQL(${posGen.primitive}, ${lenGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)") } } @@ -986,7 +856,7 @@ case class Encode(value: Expression, charset: Expression) * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { // last regex in string, we will update the pattern iff regexp value changed. 
@transient private var lastRegex: UTF8String = _ @@ -998,40 +868,26 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio // result buffer write by Matcher @transient private val result: StringBuffer = new StringBuffer - override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = rep.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String] - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) - return UTF8String.fromString(result.toString) - } - } + while (m.find) { + m.appendReplacement(result, lastReplacement) } + m.appendTail(result) - null + UTF8String.fromString(result.toString) } override def dataType: DataType = StringType @@ -1048,59 +904,43 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termResult = ctx.freshName("result") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName - val classNameString = classOf[java.lang.String].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState(classNameString, + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") ctx.addMutableState(classNameStringBuffer, termResult, s"${termResult} = new $classNameStringBuffer();") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalRep = rep.gen(ctx) - + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - ${evalSubject.code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalRep.code} - if (!${evalRep.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { - // replacement string changed - 
${termLastReplacementInUTF8} = ${evalRep.primitive}; - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); - ${ev.isNull} = false; - } - } + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); } + m.appendTail(${termResult}); + ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; """ + }) } } @@ -1110,7 +950,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. @@ -1118,32 +958,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio // last regex pattern, we cache it for performance concern @transient private var pattern: Pattern = _ - override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = idx.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } - return UTF8String.EMPTY_UTF8 - } - } + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 } - - null } override def dataType: DataType = StringType @@ -1154,44 +981,29 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = 
classOf[Pattern].getCanonicalName - ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalIdx = idx.gen(ctx) - - s""" - ${evalSubject.code} - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - boolean ${ev.isNull} = true; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalIdx.code} - if (!${evalIdx.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); - if (m.find()) { - ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); - ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); - ${ev.isNull} = false; - } else { - ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; - ${ev.isNull} = false; - } - } - } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } - """ + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.primitive} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.primitive} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) } } From fc0e57e5aba82a3f227fef05a843283e2ec893fc Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 31 Jul 2015 09:33:38 -0700 Subject: [PATCH 198/219] [SPARK-9053] [SPARKR] Fix spaces around parens, infix operators etc. ### JIRA [[SPARK-9053] Fix spaces around parens, infix operators etc. - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9053) ### The Result of `lint-r` [The result of lint-r at the rivision:a4c83cb1e4b066cd60264b6572fd3e51d160d26a](https://gist.github.com/yu-iskw/d253d7f8ef351f86443d) Author: Yu ISHIKAWA Closes #7584 from yu-iskw/SPARK-9053 and squashes the following commits: 613170f [Yu ISHIKAWA] Ignore a warning about a space before a left parentheses ede61e1 [Yu ISHIKAWA] Ignores two warnings about a space before a left parentheses. TODO: After updating `lintr`, we will remove the ignores de3e0db [Yu ISHIKAWA] Add '## nolint start' & '## nolint end' statement to ignore infix space warnings e233ea8 [Yu ISHIKAWA] [SPARK-9053][SparkR] Fix spaces around parens, infix operators etc. 
--- R/pkg/R/DataFrame.R | 4 ++++ R/pkg/R/RDD.R | 7 +++++-- R/pkg/R/column.R | 2 +- R/pkg/R/context.R | 2 +- R/pkg/R/pairRDD.R | 2 +- R/pkg/R/utils.R | 4 ++-- R/pkg/inst/tests/test_binary_function.R | 2 +- R/pkg/inst/tests/test_rdd.R | 6 +++--- R/pkg/inst/tests/test_sparkSQL.R | 4 +++- 9 files changed, 21 insertions(+), 12 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index f4c93d3c7dd67..b31ad3729e09b 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1322,9 +1322,11 @@ setMethod("write.df", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) if (!is.null(path)) { @@ -1384,9 +1386,11 @@ setMethod("saveAsTable", "org.apache.spark.sql.parquet") } allModes <- c("append", "overwrite", "error", "ignore") + # nolint start if (!(mode %in% allModes)) { stop('mode should be one of "append", "overwrite", "error", "ignore"') } + # nolint end jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) options <- varargsToEnv(...) callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index d2d096709245d..2a013b3dbb968 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -85,7 +85,9 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) isPipelinable <- function(rdd) { e <- rdd@env + # nolint start !(e$isCached || e$isCheckpointed) + # nolint end } if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) { @@ -97,7 +99,8 @@ setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) # prev_serializedMode is used during the delayed computation of JRDD in getJRDD } else { pipelinedFunc <- function(partIndex, part) { - func(partIndex, prev@func(partIndex, part)) + f <- prev@func + func(partIndex, f(partIndex, part)) } .Object@func <- cleanClosure(pipelinedFunc) .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline @@ -841,7 +844,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- rpois(1, fraction) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 2892e1416cc65..eeaf9f193b728 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -65,7 +65,7 @@ functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt", "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp", "expm1", "floor", "log", "log10", "log1p", "rint", "sign", "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians") -binary_mathfunctions<- c("atan2", "hypot") +binary_mathfunctions <- c("atan2", "hypot") createOperator <- function(op) { setMethod(op, diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 43be9c904fdf6..720990e1c6087 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -121,7 +121,7 @@ parallelize <- function(sc, coll, numSlices = 1) { numSlices <- length(coll) sliceLen <- ceiling(length(coll) / numSlices) - slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)]) + slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) # Serialize each slice: obtain a list of raws, or a list of lists (slices) of # 2-tuples of raws diff --git a/R/pkg/R/pairRDD.R 
b/R/pkg/R/pairRDD.R index 83801d3209700..199c3fd6ab1b2 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -879,7 +879,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- rpois(1, frac) if (count > 0) { - res[(len + 1):(len + count)] <- rep(list(elem), count) + res[ (len + 1) : (len + count) ] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 3f45589a50443..4f9f4d9cad2a8 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -32,7 +32,7 @@ convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, } results <- if (arrSize > 0) { - lapply(0:(arrSize - 1), + lapply(0 : (arrSize - 1), function(index) { obj <- callJMethod(jList, "get", as.integer(index)) @@ -572,7 +572,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[(lengthOfKeys + 1) : (len - 1)] + values <- part[ (lengthOfKeys + 1) : (len - 1) ] } else { values <- list() } diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index dca0657c57e0d..f054ac9a87d61 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -40,7 +40,7 @@ test_that("union on two RDDs", { expect_equal(actual, c(as.list(nums), mockFile)) expect_equal(getSerializedMode(union.rdd), "byte") - rdd<- map(text.rdd, function(x) {x}) + rdd <- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 6c3aaab8c711e..71aed2bb9d6a8 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -250,7 +250,7 @@ test_that("flatMapValues() on pairwise RDDs", { expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4))) # Generate x to x+1 for every value - actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) })) + actual <- collect(flatMapValues(intRdd, function(x) { x: (x + 1) })) expect_equal(actual, list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101), list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201))) @@ -293,7 +293,7 @@ test_that("sumRDD() on RDDs", { }) test_that("keyBy on RDDs", { - func <- function(x) { x*x } + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collect(keys) expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) @@ -311,7 +311,7 @@ test_that("repartition/coalesce on RDDs", { r2 <- repartition(rdd, 6) expect_equal(numPartitions(r2), 6L) count <- length(collectPartition(r2, 0L)) - expect_true(count >=0 && count <= 4) + expect_true(count >= 0 && count <= 4) # coalesce r3 <- coalesce(rdd, 1) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 61c8a7ec7d837..aca41aa6dcf24 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -666,10 +666,12 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6)) expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7)) expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8)) + ## nolint start expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) 
expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) + ## nolint end }) test_that("string operators", { @@ -876,7 +878,7 @@ test_that("parquetFile works with multiple input paths", { write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) expect_is(parquetDF, "DataFrame") - expect_equal(count(parquetDF), count(df)*2) + expect_equal(count(parquetDF), count(df) * 2) }) test_that("describe() on a DataFrame", { From 04a49edfdb606c01fa4f8ae6e730ec4f9bd0cb6d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 09:34:10 -0700 Subject: [PATCH 199/219] [SPARK-9497] [SPARK-9509] [CORE] Use ask instead of askWithRetry `RpcEndpointRef.askWithRetry` throws `SparkException` rather than `TimeoutException`. Use ask to replace it because we don't need to retry here. Author: zsxwing Closes #7824 from zsxwing/SPARK-9497 and squashes the following commits: 7bfc2b4 [zsxwing] Use ask instead of askWithRetry --- .../scala/org/apache/spark/deploy/client/AppClient.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 79b251e7e62fe..a659abf70395d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -248,7 +248,8 @@ private[spark] class AppClient( def stop() { if (endpoint != null) { try { - endpoint.askWithRetry[Boolean](StopAppClient) + val timeout = RpcUtils.askRpcTimeout(conf) + timeout.awaitResult(endpoint.ask[Boolean](StopAppClient)) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") From 27ae851ce16082775ffbcb5b8fc6bdbe65dc70fc Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 31 Jul 2015 18:16:55 +0100 Subject: [PATCH 200/219] [SPARK-9446] Clear Active SparkContext in stop() method In thread 'stopped SparkContext remaining active' on mailing list, Andres observed the following in driver log: ``` 15/07/29 15:17:09 WARN YarnSchedulerBackend$YarnSchedulerEndpoint: ApplicationMaster has disassociated:
15/07/29 15:17:09 INFO YarnClientSchedulerBackend: Shutting down all executors
Exception in thread "Yarn application state monitor" org.apache.spark.SparkException: Error asking standalone scheduler to shut down executors
at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stopExecutors(CoarseGrainedSchedulerBackend.scala:261)
at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stop(CoarseGrainedSchedulerBackend.scala:266)
at org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend.stop(YarnClientSchedulerBackend.scala:158)
at org.apache.spark.scheduler.TaskSchedulerImpl.stop(TaskSchedulerImpl.scala:416)
at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:1411)
at org.apache.spark.SparkContext.stop(SparkContext.scala:1644)
at org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend$$anon$1.run(YarnClientSchedulerBackend.scala:139)
Caused by: java.lang.InterruptedException
at java.util.concurrent.locks.AbstractQueuedSynchronizer.tryAcquireSharedNanos(AbstractQueuedSynchronizer.java:1325)
at scala.concurrent.impl.Promise$DefaultPromise.tryAwait(Promise.scala:208)
at scala.concurrent.impl.Promise$DefaultPromise.ready(Promise.scala:218)
at scala.concurrent.impl.Promise$DefaultPromise.result(Promise.scala:223)
at scala.concurrent.Await$$anonfun$result$1.apply(package.scala:190)
at scala.concurrent.BlockContext$DefaultBlockContext$.blockOn(BlockContext.scala:53)
at scala.concurrent.Await$.result(package.scala:190)
15/07/29 15:17:09 INFO YarnClientSchedulerBackend: Asking each executor to shut down
at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:102)
at org.apache.spark.rpc.RpcEndpointRef.askWithRetry(RpcEndpointRef.scala:78)
at org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.stopExecutors(CoarseGrainedSchedulerBackend.scala:257)
... 6 more
```

The effect of the above exception is that a stopped SparkContext is returned to the user, since SparkContext.clearActiveContext() is not called.
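For illustration only, the guard pattern this patch applies is roughly the following. This is a minimal standalone sketch, not the actual `Utils.tryLogNonFatalError` implementation; the object name and the stderr logging call are stand-ins:

```scala
import scala.util.control.NonFatal

object ShutdownGuard {
  // Run one cleanup step; log non-fatal failures instead of propagating them,
  // so the remaining shutdown steps (and clearActiveContext) still execute.
  def tryLogNonFatalError(block: => Unit): Unit = {
    try {
      block
    } catch {
      case NonFatal(e) =>
        // A real implementation would use the logging framework rather than stderr.
        System.err.println(s"Uncaught exception in thread ${Thread.currentThread().getName}: $e")
    }
  }

  def stopAll(steps: Seq[() => Unit]): Unit = {
    // Each step is wrapped independently, so a failure in one cannot abort the rest.
    steps.foreach(step => tryLogNonFatalError(step()))
  }
}
```

Wrapping each step independently is what guarantees that an InterruptedException in, say, stopExecutors cannot prevent the final clearActiveContext() call from running.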
Author: tedyu Closes #7756 from tedyu/master and squashes the following commits: 7339ff2 [tedyu] Move null assignment out of tryLogNonFatalError block 6e02cd9 [tedyu] Use Utils.tryLogNonFatalError to guard resource release f5fb519 [tedyu] Clear Active SparkContext in stop() method using finally --- .../scala/org/apache/spark/SparkContext.scala | 50 ++++++++++++++----- 1 file changed, 37 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ac6ac6c216767..2d8aa25d81daa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1689,33 +1689,57 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Utils.removeShutdownHook(_shutdownHookRef) } - postApplicationEnd() - _ui.foreach(_.stop()) + Utils.tryLogNonFatalError { + postApplicationEnd() + } + Utils.tryLogNonFatalError { + _ui.foreach(_.stop()) + } if (env != null) { - env.metricsSystem.report() + Utils.tryLogNonFatalError { + env.metricsSystem.report() + } } if (metadataCleaner != null) { - metadataCleaner.cancel() + Utils.tryLogNonFatalError { + metadataCleaner.cancel() + } + } + Utils.tryLogNonFatalError { + _cleaner.foreach(_.stop()) + } + Utils.tryLogNonFatalError { + _executorAllocationManager.foreach(_.stop()) } - _cleaner.foreach(_.stop()) - _executorAllocationManager.foreach(_.stop()) if (_dagScheduler != null) { - _dagScheduler.stop() + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } _dagScheduler = null } if (_listenerBusStarted) { - listenerBus.stop() - _listenerBusStarted = false + Utils.tryLogNonFatalError { + listenerBus.stop() + _listenerBusStarted = false + } + } + Utils.tryLogNonFatalError { + _eventLogger.foreach(_.stop()) } - _eventLogger.foreach(_.stop()) if (env != null && _heartbeatReceiver != null) { - env.rpcEnv.stop(_heartbeatReceiver) + Utils.tryLogNonFatalError { + env.rpcEnv.stop(_heartbeatReceiver) + } + } + Utils.tryLogNonFatalError { + _progressBar.foreach(_.stop()) } - _progressBar.foreach(_.stop()) _taskScheduler = null // TODO: Cache.stop()? 
if (_env != null) { - _env.stop() + Utils.tryLogNonFatalError { + _env.stop() + } SparkEnv.set(null) } SparkContext.clearActiveContext() From 0024da9157ba12ec84883a78441fa6835c1d0042 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 11:07:34 -0700 Subject: [PATCH 201/219] [SQL] address comments for to_date/trunc This PR address the comments in #7805 cc rxin Author: Davies Liu Closes #7817 from davies/trunc and squashes the following commits: f729d5f [Davies Liu] rollback cb7f7832 [Davies Liu] genCode() is protected 31e52ef [Davies Liu] fix style ed1edc7 [Davies Liu] address comments for #7805 --- .../catalyst/expressions/datetimeFunctions.scala | 15 ++++++++------- .../spark/sql/catalyst/util/DateTimeUtils.scala | 3 ++- .../expressions/ExpressionEvalHelper.scala | 4 +--- .../scala/org/apache/spark/sql/functions.scala | 3 +++ 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala index 6e7613340c032..07dea5b470b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeFunctions.scala @@ -726,15 +726,16 @@ case class TruncDate(date: Expression, format: Expression) override def dataType: DataType = DateType override def prettyName: String = "trunc" - lazy val minItemConst = DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) + private lazy val truncLevel: Int = + DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) override def eval(input: InternalRow): Any = { - val minItem = if (format.foldable) { - minItemConst + val level = if (format.foldable) { + truncLevel } else { DateTimeUtils.parseTruncLevel(format.eval().asInstanceOf[UTF8String]) } - if (minItem == -1) { + if (level == -1) { // unknown format null } else { @@ -742,7 +743,7 @@ case class TruncDate(date: Expression, format: Expression) if (d == null) { null } else { - DateTimeUtils.truncDate(d.asInstanceOf[Int], minItem) + DateTimeUtils.truncDate(d.asInstanceOf[Int], level) } } } @@ -751,7 +752,7 @@ case class TruncDate(date: Expression, format: Expression) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { - if (minItemConst == -1) { + if (truncLevel == -1) { s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; @@ -763,7 +764,7 @@ case class TruncDate(date: Expression, format: Expression) boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = $dtu.truncDate(${d.primitive}, $minItemConst); + ${ev.primitive} = $dtu.truncDate(${d.primitive}, $truncLevel); } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 5a7c25b8d508d..032ed8a56a50e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -794,7 +794,8 @@ object DateTimeUtils { } else if (level == TRUNC_TO_MONTH) { d - DateTimeUtils.getDayOfMonth(d) + 1 } else { - throw new Exception(s"Invalid trunc level: $level") + // caller make sure that this should 
+      sys.error(s"Invalid trunc level: $level")
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 3c05e5c3b833c..a41185b4d8754 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.scalactic.TripleEqualsSupport.Spread
-import org.scalatest.Matchers._
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer
 import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 46dc4605a5ccb..5d82a5eadd94d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2192,6 +2192,9 @@ object functions {
   /**
    * Returns date truncated to the unit specified by the format.
    *
+   * @param format 'year', 'yyyy', 'yy' for truncate by year,
+   *               or 'month', 'mon', 'mm' for truncate by month
+   *
    * @group datetime_funcs
    * @since 1.5.0
    */

From 6add4eddb39e7748a87da3e921ea3c7881d30a82 Mon Sep 17 00:00:00 2001
From: Alexander Ulanov
Date: Fri, 31 Jul 2015 11:22:40 -0700
Subject: [PATCH 202/219] [SPARK-9471] [ML] Multilayer Perceptron

This pull request contains the following feature for ML:
- Multilayer Perceptron classifier

This implementation is based on our initial pull request with bgreeven
(https://github.com/apache/spark/pull/1290) and inspired by very insightful
suggestions from mengxr and witgo (I would like to thank all other people
from that thread for the useful discussions). The original code was
extensively tested and benchmarked. Since then, I've addressed the two main
requirements that prevented the code from being merged into the main branch:
- An extensible interface, so it will be easy to implement new types of
  networks. The main building blocks are the traits `Layer` and `LayerModel`,
  which are used for constructing the layers of the ANN; new layers can be
  added by extending them. These traits are private in this release in order
  to leave room to improve them based on community feedback. Back propagation
  is implemented in a general form, so the optimization algorithm does not
  need to change when new layers are implemented.
- Speed and scalability: this implementation has to be comparable in terms of
  speed to state-of-the-art single-node implementations. The benchmark
  developed for large ANNs shows that the proposed code is on par with a C++
  CPU implementation and scales nicely with the number of workers. Details
  can be found here: https://github.com/avulanov/ann-benchmark
- DBN and RBM by witgo: https://github.com/witgo/spark/tree/ann-interface-gemm-dbn
- Dropout: https://github.com/avulanov/spark/tree/ann-interface-gemm

mengxr and dbtsai kindly agreed to perform the code review.
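For reviewers, a minimal usage sketch of the proposed API (hypothetical
variable names; assumes a DataFrame with "features" and "label" columns,
mirroring the included test suite):

```
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

// 4 inputs, two hidden layers with 5 and 4 neurons, 3 output classes
val trainer = new MultilayerPerceptronClassifier()
  .setLayers(Array(4, 5, 4, 3))
  .setBlockSize(128)
  .setSeed(11L)
  .setMaxIter(100)
val model = trainer.fit(train)        // train: DataFrame("features", "label")
val predicted = model.transform(test) // adds a "prediction" column
```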
Author: Alexander Ulanov
Author: Bert Greevenbosch

Closes #7621 from avulanov/SPARK-2352-ann and squashes the following commits:

4806b6f [Alexander Ulanov] Addressing reviewers comments.
a7e7951 [Alexander Ulanov] Default blockSize: 100. Added documentation to blockSize parameter and DataStacker class
f69bb3d [Alexander Ulanov] Addressing reviewers comments.
374bea6 [Alexander Ulanov] Moving ANN to ML package. GradientDescent constructor is now spark private.
43b0ae2 [Alexander Ulanov] Addressing reviewers comments. Adding multiclass test.
9d18469 [Alexander Ulanov] Addressing reviewers comments: unnecessary copy of data in predict
35125ab [Alexander Ulanov] Style fix in tests
e191301 [Alexander Ulanov] Apache header
a226133 [Alexander Ulanov] Multilayer Perceptron regressor and classifier
---
 .../org/apache/spark/ml/ann/BreezeUtil.scala       |  63 ++
 .../scala/org/apache/spark/ml/ann/Layer.scala      | 882 ++++++++++++++++++
 .../MultilayerPerceptronClassifier.scala           | 193 ++++
 .../org/apache/spark/ml/param/params.scala         |   5 +
 .../mllib/optimization/GradientDescent.scala       |   2 +-
 .../org/apache/spark/ml/ann/ANNSuite.scala         |  91 ++
 .../MultilayerPerceptronClassifierSuite.scala      |  91 ++
 7 files changed, 1326 insertions(+), 1 deletion(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
new file mode 100644
index 0000000000000..7429f9d652ac5
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/BreezeUtil.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV}
+import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
+
+/**
+ * In-place DGEMM and DGEMV for Breeze
+ */
+private[ann] object BreezeUtil {
+
+  // TODO: switch to MLlib BLAS interface
+  private def transposeString(a: BDM[Double]): String = if (a.isTranspose) "T" else "N"
+
+  /**
+   * DGEMM: C := alpha * A * B + beta * C
+   * @param alpha alpha
+   * @param a A
+   * @param b B
+   * @param beta beta
+   * @param c C
+   */
+  def dgemm(alpha: Double, a: BDM[Double], b: BDM[Double], beta: Double, c: BDM[Double]): Unit = {
+    // TODO: add code for the case when matrices are transposed
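+    // Shape checks for C := alpha * A * B + beta * C (A: m x k, B: k x n, C: m x n)
+    // before the raw arrays are handed to native BLAS.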
+    require(a.cols == b.rows, "A & B Dimension mismatch!")
+    require(a.rows == c.rows, "A & C Dimension mismatch!")
+    require(b.cols == c.cols, "B & C Dimension mismatch!")
+    NativeBLAS.dgemm(transposeString(a), transposeString(b), c.rows, c.cols, a.cols,
+      alpha, a.data, a.offset, a.majorStride, b.data, b.offset, b.majorStride,
+      beta, c.data, c.offset, c.rows)
+  }
+
+  /**
+   * DGEMV: y := alpha * A * x + beta * y
+   * @param alpha alpha
+   * @param a A
+   * @param x x
+   * @param beta beta
+   * @param y y
+   */
+  def dgemv(alpha: Double, a: BDM[Double], x: BDV[Double], beta: Double, y: BDV[Double]): Unit = {
+    require(a.cols == x.length, "A & x Dimension mismatch!")
+    NativeBLAS.dgemv(transposeString(a), a.rows, a.cols,
+      alpha, a.data, a.offset, a.majorStride, x.data, x.offset, x.stride,
+      beta, y.data, y.offset, y.stride)
+  }
+}

diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
new file mode 100644
index 0000000000000..b5258ff348477
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -0,0 +1,882 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy,
+  sum => Bsum}
+import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.optimization._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Trait that holds the layer properties needed to instantiate it.
+ * Implements layer instantiation.
+ */
+private[ann] trait Layer extends Serializable {
+  /**
+   * Returns an instance of the layer based on the weights provided
+   * @param weights vector with layer weights
+   * @param position position of weights in the vector
+   * @return the layer model
+   */
+  def getInstance(weights: Vector, position: Int): LayerModel
+
+  /**
+   * Returns an instance of the layer with randomly generated weights
+   * @param seed seed
+   * @return the layer model
+   */
+  def getInstance(seed: Long): LayerModel
+}
+
+/**
+ * Trait that holds the layer weights (or parameters).
+ * Implements the functions needed for forward propagation, computing the delta and the gradient.
+ * Can return the weights in Vector format.
+ */ +private[ann] trait LayerModel extends Serializable { + /** + * number of weights + */ + val size: Int + + /** + * Evaluates the data (process the data through the layer) + * @param data data + * @return processed data + */ + def eval(data: BDM[Double]): BDM[Double] + + /** + * Computes the delta for back propagation + * @param nextDelta delta of the next layer + * @param input input data + * @return delta + */ + def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + + /** + * Computes the gradient + * @param delta delta for this layer + * @param input input data + * @return gradient + */ + def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] + + /** + * Returns weights for the layer in a single vector + * @return layer weights + */ + def weights(): Vector +} + +/** + * Layer properties of affine transformations, that is y=A*x+b + * @param numIn number of inputs + * @param numOut number of outputs + */ +private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { + + override def getInstance(weights: Vector, position: Int): LayerModel = { + AffineLayerModel(this, weights, position) + } + + override def getInstance(seed: Long = 11L): LayerModel = { + AffineLayerModel(this, seed) + } +} + +/** + * Model of Affine layer y=A*x+b + * @param w weights (matrix A) + * @param b bias (vector b) + */ +private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel { + val size = w.size + b.length + val gwb = new Array[Double](size) + private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) + private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) + private var z: BDM[Double] = null + private var d: BDM[Double] = null + private var ones: BDV[Double] = null + + override def eval(data: BDM[Double]): BDM[Double] = { + if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) + z(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, z) + z + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) + BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) + gwb + } + + override def weights(): Vector = AffineLayerModel.roll(w, b) +} + +/** + * Fabric for Affine layer models + */ +private[ann] object AffineLayerModel { + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param weights vector with weights + * @param position position of weights in the vector + * @return model of Affine layer + */ + def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { + val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) + new AffineLayerModel(w, b) + } + + /** + * Creates a model of Affine layer + * @param layer layer properties + * @param seed seed + * @return model of Affine layer + */ + def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { + val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) + new AffineLayerModel(w, b) + } + + /** + * Unrolls the weights from the vector + * @param weights vector with weights + * @param position position of weights for this layer + * @param numIn number of layer inputs + * 
@param numOut number of layer outputs + * @return matrix A and vector b + */ + def unroll( + weights: Vector, + position: Int, + numIn: Int, + numOut: Int): (BDM[Double], BDV[Double]) = { + val weightsCopy = weights.toArray + // TODO: the array is not copied to BDMs, make sure this is OK! + val a = new BDM[Double](numOut, numIn, weightsCopy, position) + val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) + (a, b) + } + + /** + * Roll the layer weights into a vector + * @param a matrix A + * @param b vector b + * @return vector of weights + */ + def roll(a: BDM[Double], b: BDV[Double]): Vector = { + val result = new Array[Double](a.size + b.length) + // TODO: make sure that we need to copy! + System.arraycopy(a.toArray, 0, result, 0, a.size) + System.arraycopy(b.toArray, 0, result, a.size, b.length) + Vectors.dense(result) + } + + /** + * Generate random weights for the layer + * @param numIn number of inputs + * @param numOut number of outputs + * @param seed seed + * @return (matrix A, vector b) + */ + def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { + val rand: XORShiftRandom = new XORShiftRandom(seed) + val weights = BDM.fill[Double](numOut, numIn){ (rand.nextDouble * 4.8 - 2.4) / numIn } + val bias = BDV.fill[Double](numOut){ (rand.nextDouble * 4.8 - 2.4) / numIn } + (weights, bias) + } +} + +/** + * Trait for functions and their derivatives for functional layers + */ +private[ann] trait ActivationFunction extends Serializable { + + /** + * Implements a function + * @param x input data + * @param y output data + */ + def eval(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a derivative of a function (needed for the back propagation) + * @param x input data + * @param y output data + */ + def derivative(x: BDM[Double], y: BDM[Double]): Unit + + /** + * Implements a cross entropy error of a function. + * Needed if the functional layer that contains this function is the output layer + * of the network. 
+ * @param target target output + * @param output computed output + * @param result intermediate result + * @return cross-entropy + */ + def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + + /** + * Implements a mean squared error of a function + * @param target target output + * @param output computed output + * @param result intermediate result + * @return mean squared error + */ + def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double +} + +/** + * Implements in-place application of functions + */ +private[ann] object ActivationFunction { + + def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { + var i = 0 + while (i < x.rows) { + var j = 0 + while (j < x.cols) { + y(i, j) = func(x(i, j)) + j += 1 + } + i += 1 + } + } + + def apply( + x1: BDM[Double], + x2: BDM[Double], + y: BDM[Double], + func: (Double, Double) => Double): Unit = { + var i = 0 + while (i < x1.rows) { + var j = 0 + while (j < x1.cols) { + y(i, j) = func(x1(i, j), x2(i, j)) + j += 1 + } + i += 1 + } + } +} + +/** + * Implements SoftMax activation function + */ +private[ann] class SoftmaxFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < x.cols) { + var i = 0 + var max = Double.MinValue + while (i < x.rows) { + if (x(i, j) > max) { + max = x(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < x.rows) { + val res = Math.exp(x(i, j) - max) + y(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < x.rows) { + y(i, j) /= sum + i += 1 + } + j += 1 + } + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum( target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") + } +} + +/** + * Implements Sigmoid activation function + */ +private[ann] class SigmoidFunction extends ActivationFunction { + override def eval(x: BDM[Double], y: BDM[Double]): Unit = { + def s(z: Double): Double = Bsigmoid(z) + ActivationFunction(x, y, s) + } + + override def crossEntropy( + output: BDM[Double], + target: BDM[Double], + result: BDM[Double]): Double = { + def m(o: Double, t: Double): Double = o - t + ActivationFunction(output, target, result, m) + -Bsum(target :* Blog(output)) / output.cols + } + + override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { + def sd(z: Double): Double = (1 - z) * z + ActivationFunction(x, y, sd) + } + + override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { + // TODO: make it readable + def m(o: Double, t: Double): Double = (o - t) + ActivationFunction(output, target, result, m) + val e = Bsum(result :* result) / 2 / output.cols + def m2(x: Double, o: Double) = x * (o - o * o) + ActivationFunction(result, output, result, m2) + e + } +} + +/** + * Functional layer properties, y = f(x) + * @param activationFunction activation function + */ +private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { + override def 
getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) + + override def getInstance(seed: Long): LayerModel = + FunctionalLayerModel(this) +} + +/** + * Functional layer model. Holds no weights. + * @param activationFunction activation function + */ +private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) + extends LayerModel { + val size = 0 + // matrices for in-place computations + // outputs + private var f: BDM[Double] = null + // delta + private var d: BDM[Double] = null + // matrix for error computation + private var e: BDM[Double] = null + // delta gradient + private lazy val dg = new Array[Double](0) + + override def eval(data: BDM[Double]): BDM[Double] = { + if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) + activationFunction.eval(data, f) + f + } + + override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { + if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) + activationFunction.derivative(input, d) + d :*= nextDelta + d + } + + override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg + + override def weights(): Vector = Vectors.dense(new Array[Double](0)) + + def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.crossEntropy(output, target, e) + (e, error) + } + + def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) + val error = activationFunction.squared(output, target, e) + (e, error) + } + + def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { + // TODO: allow user pick error + activationFunction match { + case sigmoid: SigmoidFunction => squared(output, target) + case softmax: SoftmaxFunction => crossEntropy(output, target) + } + } +} + +/** + * Fabric of functional layer models + */ +private[ann] object FunctionalLayerModel { + def apply(layer: FunctionalLayer): FunctionalLayerModel = + new FunctionalLayerModel(layer.activationFunction) +} + +/** + * Trait for the artificial neural network (ANN) topology properties + */ +private[ann] trait Topology extends Serializable{ + def getInstance(weights: Vector): TopologyModel + def getInstance(seed: Long): TopologyModel +} + +/** + * Trait for ANN topology model + */ +private[ann] trait TopologyModel extends Serializable{ + /** + * Forward propagation + * @param data input data + * @return array of outputs for each of the layers + */ + def forward(data: BDM[Double]): Array[BDM[Double]] + + /** + * Prediction of the model + * @param data input data + * @return prediction + */ + def predict(data: Vector): Vector + + /** + * Computes gradient for the network + * @param data input data + * @param target target output + * @param cumGradient cumulative gradient + * @param blockSize block size + * @return error + */ + def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, + blockSize: Int): Double + + /** + * Returns the weights of the ANN + * @return weights + */ + def weights(): Vector +} + +/** + * Feed forward ANN + * @param layers + */ +private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { + override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + + override def 
getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) +} + +/** + * Factory for some of the frequently-used topologies + */ +private[ml] object FeedForwardTopology { + /** + * Creates a feed forward topology from the array of layers + * @param layers array of layers + * @return feed forward topology + */ + def apply(layers: Array[Layer]): FeedForwardTopology = { + new FeedForwardTopology(layers) + } + + /** + * Creates a multi-layer perceptron + * @param layerSizes sizes of layers including input and output size + * @param softmax wether to use SoftMax or Sigmoid function for an output layer. + * Softmax is default + * @return multilayer perceptron topology + */ + def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + val layers = new Array[Layer]((layerSizes.length - 1) * 2) + for(i <- 0 until layerSizes.length - 1){ + layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) + layers(i * 2 + 1) = + if (softmax && i == layerSizes.length - 2) { + new FunctionalLayer(new SoftmaxFunction()) + } else { + new FunctionalLayer(new SigmoidFunction()) + } + } + FeedForwardTopology(layers) + } +} + +/** + * Model of Feed Forward Neural Network. + * Implements forward, gradient computation and can return weights in vector format. + * @param layerModels models of layers + * @param topology topology of the network + */ +private[ml] class FeedForwardModel private( + val layerModels: Array[LayerModel], + val topology: FeedForwardTopology) extends TopologyModel { + override def forward(data: BDM[Double]): Array[BDM[Double]] = { + val outputs = new Array[BDM[Double]](layerModels.length) + outputs(0) = layerModels(0).eval(data) + for (i <- 1 until layerModels.length) { + outputs(i) = layerModels(i).eval(outputs(i-1)) + } + outputs + } + + override def computeGradient( + data: BDM[Double], + target: BDM[Double], + cumGradient: Vector, + realBatchSize: Int): Double = { + val outputs = forward(data) + val deltas = new Array[BDM[Double]](layerModels.length) + val L = layerModels.length - 1 + val (newE, newError) = layerModels.last match { + case flm: FunctionalLayerModel => flm.error(outputs.last, target) + case _ => + throw new UnsupportedOperationException("Non-functional layer not supported at the top") + } + deltas(L) = new BDM[Double](0, 0) + deltas(L - 1) = newE + for (i <- (L - 2) to (0, -1)) { + deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) + } + val grads = new Array[Array[Double]](layerModels.length) + for (i <- 0 until layerModels.length) { + val input = if (i==0) data else outputs(i - 1) + grads(i) = layerModels(i).grad(deltas(i), input) + } + // update cumGradient + val cumGradientArray = cumGradient.toArray + var offset = 0 + // TODO: extract roll + for (i <- 0 until grads.length) { + val gradArray = grads(i) + var k = 0 + while (k < gradArray.length) { + cumGradientArray(offset + k) += gradArray(k) + k += 1 + } + offset += gradArray.length + } + newError + } + + // TODO: do we really need to copy the weights? 
they should be read-only
+  override def weights(): Vector = {
+    // TODO: extract roll
+    var size = 0
+    for (i <- 0 until layerModels.length) {
+      size += layerModels(i).size
+    }
+    val array = new Array[Double](size)
+    var offset = 0
+    for (i <- 0 until layerModels.length) {
+      val layerWeights = layerModels(i).weights().toArray
+      System.arraycopy(layerWeights, 0, array, offset, layerWeights.length)
+      offset += layerWeights.length
+    }
+    Vectors.dense(array)
+  }
+
+  override def predict(data: Vector): Vector = {
+    val size = data.size
+    val result = forward(new BDM[Double](size, 1, data.toArray))
+    Vectors.dense(result.last.toArray)
+  }
+}
+
+/**
+ * Factory for feed-forward ANN models
+ */
+private[ann] object FeedForwardModel {
+
+  /**
+   * Creates a model from a topology and weights
+   * @param topology topology
+   * @param weights weights
+   * @return model
+   */
+  def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
+    val layers = topology.layers
+    val layerModels = new Array[LayerModel](layers.length)
+    var offset = 0
+    for (i <- 0 until layers.length) {
+      layerModels(i) = layers(i).getInstance(weights, offset)
+      offset += layerModels(i).size
+    }
+    new FeedForwardModel(layerModels, topology)
+  }
+
+  /**
+   * Creates a model given a topology and seed
+   * @param topology topology
+   * @param seed seed for generating the weights
+   * @return model
+   */
+  def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
+    val layers = topology.layers
+    val layerModels = new Array[LayerModel](layers.length)
+    var offset = 0
+    for (i <- 0 until layers.length) {
+      layerModels(i) = layers(i).getInstance(seed)
+      offset += layerModels(i).size
+    }
+    new FeedForwardModel(layerModels, topology)
+  }
+}
+
+/**
+ * Neural network gradient. Does nothing but delegate to the model's gradient computation
+ * @param topology topology
+ * @param dataStacker data stacker
+ */
+private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient {
+
+  override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+    val gradient = Vectors.zeros(weights.size)
+    val loss = compute(data, label, weights, gradient)
+    (gradient, loss)
+  }
+
+  override def compute(
+      data: Vector,
+      label: Double,
+      weights: Vector,
+      cumGradient: Vector): Double = {
+    val (input, target, realBatchSize) = dataStacker.unstack(data)
+    val model = topology.getInstance(weights)
+    model.computeGradient(input, target, cumGradient, realBatchSize)
+  }
+}
+
+/**
+ * Stacks pairs of training samples (input, output) into one vector, allowing them to pass
+ * through the Optimizer/Gradient interfaces. If stackSize is more than one, it makes blocks
+ * (matrices) of inputs and outputs and then stacks them into one vector.
+ * This can be used for further batch computations after unstacking.
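+ * For example, with stackSize 2 each stacked vector packs two (input, output)
+ * pairs, all inputs first and then all outputs, so that unstacking can rebuild
+ * the input and output matrices for batch operations.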
+ * @param stackSize stack size + * @param inputSize size of the input vectors + * @param outputSize size of the output vectors + */ +private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) + extends Serializable { + + /** + * Stacks the data + * @param data RDD of vector pairs + * @return RDD of double (always zero) and vector that contains the stacked vectors + */ + def stack(data: RDD[(Vector, Vector)]): RDD[(Double, Vector)] = { + val stackedData = if (stackSize == 1) { + data.map { v => + (0.0, + Vectors.fromBreeze(BDV.vertcat( + v._1.toBreeze.toDenseVector, + v._2.toBreeze.toDenseVector)) + ) } + } else { + data.mapPartitions { it => + it.grouped(stackSize).map { seq => + val size = seq.size + val bigVector = new Array[Double](inputSize * size + outputSize * size) + var i = 0 + seq.foreach { case (in, out) => + System.arraycopy(in.toArray, 0, bigVector, i * inputSize, inputSize) + System.arraycopy(out.toArray, 0, bigVector, + inputSize * size + i * outputSize, outputSize) + i += 1 + } + (0.0, Vectors.dense(bigVector)) + } + } + } + stackedData + } + + /** + * Unstack the stacked vectors into matrices for batch operations + * @param data stacked vector + * @return pair of matrices holding input and output data and the real stack size + */ + def unstack(data: Vector): (BDM[Double], BDM[Double], Int) = { + val arrData = data.toArray + val realStackSize = arrData.length / (inputSize + outputSize) + val input = new BDM(inputSize, realStackSize, arrData) + val target = new BDM(outputSize, realStackSize, arrData, inputSize * realStackSize) + (input, target, realStackSize) + } +} + +/** + * Simple updater + */ +private[ann] class ANNUpdater extends Updater { + + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { + val thisIterStepSize = stepSize + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + (Vectors.fromBreeze(brzWeights), 0) + } +} + +/** + * MLlib-style trainer class that trains a network given the data and topology + * @param topology topology of ANN + * @param inputSize input size + * @param outputSize output size + */ +private[ml] class FeedForwardTrainer( + topology: Topology, + val inputSize: Int, + val outputSize: Int) extends Serializable { + + // TODO: what if we need to pass random seed? 
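+  // Defaults: weights generated from the fixed seed 11L, stack size 128, and an
+  // LBFGS optimizer with convergence tolerance 1e-4 and at most 100 iterations.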
+ private var _weights = topology.getInstance(11L).weights() + private var _stackSize = 128 + private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) + private var _gradient: Gradient = new ANNGradient(topology, dataStacker) + private var _updater: Updater = new ANNUpdater() + private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + + /** + * Returns weights + * @return weights + */ + def getWeights: Vector = _weights + + /** + * Sets weights + * @param value weights + * @return trainer + */ + def setWeights(value: Vector): FeedForwardTrainer = { + _weights = value + this + } + + /** + * Sets the stack size + * @param value stack size + * @return trainer + */ + def setStackSize(value: Int): FeedForwardTrainer = { + _stackSize = value + dataStacker = new DataStacker(value, inputSize, outputSize) + this + } + + /** + * Sets the SGD optimizer + * @return SGD optimizer + */ + def SGDOptimizer: GradientDescent = { + val sgd = new GradientDescent(_gradient, _updater) + optimizer = sgd + sgd + } + + /** + * Sets the LBFGS optimizer + * @return LBGS optimizer + */ + def LBFGSOptimizer: LBFGS = { + val lbfgs = new LBFGS(_gradient, _updater) + optimizer = lbfgs + lbfgs + } + + /** + * Sets the updater + * @param value updater + * @return trainer + */ + def setUpdater(value: Updater): FeedForwardTrainer = { + _updater = value + updateUpdater(value) + this + } + + /** + * Sets the gradient + * @param value gradient + * @return trainer + */ + def setGradient(value: Gradient): FeedForwardTrainer = { + _gradient = value + updateGradient(value) + this + } + + private[this] def updateGradient(gradient: Gradient): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setGradient(gradient) + case sgd: GradientDescent => sgd.setGradient(gradient) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + private[this] def updateUpdater(updater: Updater): Unit = { + optimizer match { + case lbfgs: LBFGS => lbfgs.setUpdater(updater) + case sgd: GradientDescent => sgd.setUpdater(updater) + case other => throw new UnsupportedOperationException( + s"Only LBFGS and GradientDescent are supported but got ${other.getClass}.") + } + } + + /** + * Trains the ANN + * @param data RDD of input and output vector pairs + * @return model + */ + def train(data: RDD[(Vector, Vector)]): TopologyModel = { + val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) + topology.getInstance(newWeights) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala new file mode 100644 index 0000000000000..8cd2103d7d5e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} +import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} +import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} +import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql.DataFrame + +/** Params for Multilayer Perceptron. */ +private[ml] trait MultilayerPerceptronParams extends PredictorParams + with HasSeed with HasMaxIter with HasTol { + /** + * Layer sizes including input size and output size. + * @group param + */ + final val layers: IntArrayParam = new IntArrayParam(this, "layers", + "Sizes of layers from input layer to output layer" + + " E.g., Array(780, 100, 10) means 780 inputs, " + + "one hidden layer with 100 neurons and output layer of 10 neurons.", + // TODO: how to check ALSO that all elements are greater than 0? + ParamValidators.arrayLengthGt(1) + ) + + /** @group setParam */ + def setLayers(value: Array[Int]): this.type = set(layers, value) + + /** @group getParam */ + final def getLayers: Array[Int] = $(layers) + + /** + * Block size for stacking input data in matrices to speed up the computation. + * Data is stacked within partitions. If block size is more than remaining data in + * a partition then it is adjusted to the size of this data. + * Recommended size is between 10 and 1000. + * @group expertParam + */ + final val blockSize: IntParam = new IntParam(this, "blockSize", + "Block size for stacking input data in matrices. Data is stacked within partitions." + + " If block size is more than remaining data in a partition then " + + "it is adjusted to the size of this data. Recommended size is between 10 and 1000", + ParamValidators.gt(0)) + + /** @group setParam */ + def setBlockSize(value: Int): this.type = set(blockSize, value) + + /** @group getParam */ + final def getBlockSize: Int = $(blockSize) + + /** + * Set the maximum number of iterations. + * Default is 100. + * @group setParam + */ + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** + * Set the convergence tolerance of iterations. + * Smaller value will lead to higher accuracy with the cost of more iterations. + * Default is 1E-4. + * @group setParam + */ + def setTol(value: Double): this.type = set(tol, value) + + /** + * Set the seed for weights initialization. + * @group setParam + */ + def setSeed(value: Long): this.type = set(seed, value) + + setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) +} + +/** Label to vector converter. */ +private object LabelConverter { + // TODO: Use OneHotEncoder instead + /** + * Encodes a label as a vector. + * Returns a vector of given length with zeroes at all positions + * and value 1.0 at the position that corresponds to the label. 
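+   * E.g., label 2.0 with labelCount 4 is encoded as [0.0, 0.0, 1.0, 0.0].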
+ * + * @param labeledPoint labeled point + * @param labelCount total number of labels + * @return pair of features and vector encoding of a label + */ + def encodeLabeledPoint(labeledPoint: LabeledPoint, labelCount: Int): (Vector, Vector) = { + val output = Array.fill(labelCount)(0.0) + output(labeledPoint.label.toInt) = 1.0 + (labeledPoint.features, Vectors.dense(output)) + } + + /** + * Converts a vector to a label. + * Returns the position of the maximal element of a vector. + * + * @param output label encoded with a vector + * @return label + */ + def decodeLabel(output: Vector): Double = { + output.argmax.toDouble + } +} + +/** + * :: Experimental :: + * Classifier trainer based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * Number of inputs has to be equal to the size of feature vectors. + * Number of outputs has to be equal to the total number of labels. + * + */ +@Experimental +class MultilayerPerceptronClassifier(override val uid: String) + extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] + with MultilayerPerceptronParams { + + def this() = this(Identifiable.randomUID("mlpc")) + + override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) + + /** + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @return Fitted model + */ + override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { + val myLayers = $(layers) + val labels = myLayers.last + val lpData = extractLabeledPoints(dataset) + val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) + val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) + val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) + FeedForwardTrainer.setStackSize($(blockSize)) + val mlpModel = FeedForwardTrainer.train(data) + new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) + } +} + +/** + * :: Experimental :: + * Classifier model based on the Multilayer Perceptron. + * Each layer has sigmoid activation function, output layer has softmax. + * @param uid uid + * @param layers array of layer sizes including input and output layers + * @param weights vector of initial weights for the model that consists of the weights of layers + * @return prediction model + */ +@Experimental +class MultilayerPerceptronClassifierModel private[ml] ( + override val uid: String, + layers: Array[Int], + weights: Vector) + extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] + with Serializable { + + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + + /** + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
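+   * The returned label is the index of the maximal element of the network output.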
+ */ + override protected def predict(features: Vector): Double = { + LabelConverter.decodeLabel(mlpModel.predict(features)) + } + + override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { + copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 954aa17e26a02..d68f5ff0053c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -166,6 +166,11 @@ object ParamValidators { def inArray[T](allowed: java.util.List[T]): T => Boolean = { (value: T) => allowed.contains(value) } + + /** Check that the array length is greater than lowerBound. */ + def arrayLengthGt[T](lowerBound: Double): Array[T] => Boolean = { (value: Array[T]) => + value.length > lowerBound + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index ab7611fd077ef..8f0d1e4aa010a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent private[mllib] (private var gradient: Gradient, private var updater: Updater) +class GradientDescent private[spark] (private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala new file mode 100644 index 0000000000000..1292e57d7c01a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.ann + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + + +class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { + + // TODO: test for weights comparison with Weka MLP + test("ANN with Sigmoid learns XOR function with LBFGS optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array(0.0, 1.0, 1.0, 0.0) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 1) + trainer.setWeights(initialWeights) + trainer.LBFGSOptimizer.setNumIterations(20) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input)(0), label(0)) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(math.round(p) === l) + } + } + + test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") { + val inputs = Array( + Array(0.0, 0.0), + Array(0.0, 1.0), + Array(1.0, 0.0), + Array(1.0, 1.0) + ) + val outputs = Array( + Array(1.0, 0.0), + Array(0.0, 1.0), + Array(0.0, 1.0), + Array(1.0, 0.0) + ) + val data = inputs.zip(outputs).map { case (features, label) => + (Vectors.dense(features), Vectors.dense(label)) + } + val rddData = sc.parallelize(data, 1) + val hiddenLayersTopology = Array(5) + val dataSample = rddData.first() + val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size + val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) + val initialWeights = FeedForwardModel(topology, 23124).weights() + val trainer = new FeedForwardTrainer(topology, 2, 2) + trainer.SGDOptimizer.setNumIterations(2000) + trainer.setWeights(initialWeights) + val model = trainer.train(rddData) + val predictionAndLabels = rddData.map { case (input, label) => + (model.predict(input), label) + }.collect() + predictionAndLabels.foreach { case (p, l) => + assert(p ~== l absTol 0.5) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala new file mode 100644 index 0000000000000..ddc948f65df45 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.sql.Row + +class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("XOR function learning as binary classification problem with two outputs.") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(100) + val model = trainer.fit(dataFrame) + val result = model.transform(dataFrame) + val predictionAndLabels = result.select("prediction", "label").collect() + predictionAndLabels.foreach { case Row(p: Double, l: Double) => + assert(p == l) + } + } + + // TODO: implement a more rigorous test + test("3 class classification with 2 hidden layers") { + val nPoints = 1000 + + // The following weights are taken from OneVsRestSuite.scala + // they represent 3-class iris dataset + val weights = Array( + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682) + + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val rdd = sc.parallelize(generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 2) + val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") + val numClasses = 3 + val numIterations = 100 + val layers = Array[Int](4, 5, 4, numClasses) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(11L) + .setMaxIter(numIterations) + val model = trainer.fit(dataFrame) + val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label") + .map { case Row(p: Double, l: Double) => (p, l) } + // train multinomial logistic regression + val lr = new LogisticRegressionWithLBFGS() + .setIntercept(true) + .setNumClasses(numClasses) + lr.optimizer.setRegParam(0.0) + .setNumIterations(numIterations) + val lrModel = lr.run(rdd) + val lrPredictionAndLabels = lrModel.predict(rdd.map(_.features)).zip(rdd.map(_.label)) + // MLP's predictions should not differ a lot from LR's. 
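+    // Both confusion matrices are computed over the same 1000 points, so an
+    // absolute tolerance of 100 lets the two models disagree on roughly 10% of them.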
+ val lrMetrics = new MulticlassMetrics(lrPredictionAndLabels) + val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) + assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) + } +} From 4011a947154d97a9ffb5a71f077481a12534d36b Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Fri, 31 Jul 2015 11:50:15 -0700 Subject: [PATCH 203/219] [SPARK-9231] [MLLIB] DistributedLDAModel method for top topics per document jira: https://issues.apache.org/jira/browse/SPARK-9231 Helper method in DistributedLDAModel of this form: ``` /** * For each document, return the top k weighted topics for that document. * return RDD of (doc ID, topic indices, topic weights) */ def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] ``` Author: Yuhao Yang Closes #7785 from hhbyyh/topTopicsPerdoc and squashes the following commits: 30ad153 [Yuhao Yang] small fix fd24580 [Yuhao Yang] add topTopics per document to DistributedLDAModel --- .../spark/mllib/clustering/LDAModel.scala | 19 ++++++++++++++++++- .../spark/mllib/clustering/LDASuite.scala | 13 ++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6cfad3fbbdb87..82281a0daf008 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, normalize, sum} +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argtopk, normalize, sum} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -591,6 +591,23 @@ class DistributedLDAModel private[clustering] ( JavaPairRDD.fromRDD(topicDistributions.asInstanceOf[RDD[(java.lang.Long, Vector)]]) } + /** + * For each document, return the top k weighted topics for that document and their weights. + * @return RDD of (doc ID, topic indices, topic weights) + */ + def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + val topIndices = argtopk(topicCounts, k) + val sumCounts = sum(topicCounts) + val weights = if (sumCounts != 0) { + topicCounts(topIndices) / sumCounts + } else { + topicCounts(topIndices) + } + (docID.toLong, topIndices.toArray, weights.toArray) + } + } + // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? 
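For illustration, a usage sketch of the new helper (hypothetical names; `model`
is assumed to be a trained `DistributedLDAModel`):

```
// Top 3 weighted topics for each training document,
// as (doc ID, topic indices, topic weights)
val topTopics: RDD[(Long, Array[Int], Array[Double])] = model.topTopicsPerDocument(3)
```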
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index c43e1e575c09c..695ee3b82efc5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.mllib.clustering
 
-import breeze.linalg.{DenseMatrix => BDM, max, argmax}
+import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.graphx.Edge
@@ -108,6 +108,17 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
     }
 
+    val top2TopicsPerDoc = model.topTopicsPerDocument(2).map(t => (t._1, (t._2, t._3)))
+    model.topicDistributions.join(top2TopicsPerDoc).collect().foreach {
+      case (docId, (topicDistribution, (indices, weights))) =>
+        assert(indices.length == 2)
+        assert(weights.length == 2)
+        val bdvTopicDist = topicDistribution.toBreeze
+        val top2Indices = argtopk(bdvTopicDist, 2)
+        assert(top2Indices.toArray === indices)
+        assert(bdvTopicDist(top2Indices).toArray === weights)
+    }
+
     // Check: log probabilities
     assert(model.logLikelihood < 0.0)
     assert(model.logPrior < 0.0)

From e8bdcdeabb2df139a656f86686cdb53c891b1f4b Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Fri, 31 Jul 2015 11:56:52 -0700
Subject: [PATCH 204/219] [SPARK-6885] [ML] decision tree support predict class probabilities

Decision trees now support predicting class probabilities. The probability
prediction follows the old DecisionTree API and the
[sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593).
DecisionTreeClassificationModel now inherits from
ProbabilisticClassificationModel: predictRaw returns the vector of raw
per-class counts, and raw2probabilityInPlace/predictProbability return the
probabilities for each prediction.
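For illustration, a minimal sketch of what the change enables (hypothetical
names; `train` and `test` are assumed DataFrames with "features" and "label"
columns):

```
import org.apache.spark.ml.classification.DecisionTreeClassifier

val model = new DecisionTreeClassifier().setMaxDepth(5).fit(train)
// "rawPrediction" holds the per-class counts from the matched leaf;
// "probability" is the normalized form computed by raw2probabilityInPlace.
model.transform(test).select("prediction", "rawPrediction", "probability").show()
```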
Author: Yanbo Liang Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities --- .../DecisionTreeClassifier.scala | 40 ++++-- .../ml/classification/GBTClassifier.scala | 2 +- .../RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../scala/org/apache/spark/ml/tree/Node.scala | 80 ++++++----- .../spark/ml/tree/impl/RandomForest.scala | 126 ++++++++---------- .../spark/mllib/tree/impurity/Entropy.scala | 2 +- .../spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../tree/model/InformationGainStats.scala | 61 ++++++++- .../DecisionTreeClassifierSuite.scala | 30 ++++- .../classification/GBTClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- 16 files changed, 229 insertions(+), 130 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 36fe1bd40469c..f27cfd0331419 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,12 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} @@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame */ @Experimental final class DecisionTreeClassifier(override val uid: String) - extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("dtc")) @@ -106,8 +105,9 @@ object DecisionTreeClassifier { @Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, - override val 
rootNode: Node) - extends PredictionModel[Vector, DecisionTreeClassificationModel] + override val rootNode: Node, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { require(rootNode != null, @@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode) + def this(rootNode: Node, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction + } + + override protected def predictRaw(features: Vector): Vector = { + Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val sum = dv.values.sum + while (i < size) { + dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0 + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) } override def toString: String = { @@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel { s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode) + new DecisionTreeClassificationModel(uid, rootNode, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index eb0b1a0a405fc..c3891a9599262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -190,7 +190,7 @@ final class GBTClassificationModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bc19bd6df894f..0c7eb4a662fdb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] ( // Ignore the weights since all are 1.0 for now. 
val votes = new Array[Double](numClasses) _trees.view.foreach { tree => - val prediction = tree.rootNode.predict(features).toInt + val prediction = tree.rootNode.predictImpl(features).prediction.toInt votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } Vectors.dense(votes) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6f3340c2f02be..4d30e4b5548aa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] ( def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index e38dc73ee0ba7..5633bc320273a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -180,7 +180,7 @@ final class GBTRegressionModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 506a878c2553b..17fb1ad5e15d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. 
- _trees.map(_.rootNode.predict(features)).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index bbc2427ca7d3d..8879352a600a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict} + Node => OldNode, Predict => OldPredict, ImpurityStats} /** * :: DeveloperApi :: @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable { /** Impurity measure at this node (for training data) */ def impurity: Double + /** + * Statistics aggregated from training data at this node, used to compute prediction, impurity, + * and probabilities. + * For classification, the array of class counts must be normalized to a probability distribution. + */ + private[tree] def impurityStats: ImpurityCalculator + /** Recursive prediction helper method */ - private[ml] def predict(features: Vector): Double = prediction + private[ml] def predictImpl(features: Vector): LeafNode /** * Get the number of nodes in tree below this node, including leaf nodes. @@ -75,7 +83,8 @@ private[ml] object Node { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain @@ -85,7 +94,7 @@ private[ml] object Node { new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } @@ -99,11 +108,13 @@ private[ml] object Node { @DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, - override val impurity: Double) extends Node { + override val impurity: Double, + override val impurityStats: ImpurityCalculator) extends Node { - override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" - override private[ml] def predict(features: Vector): Double = prediction + override private[ml] def predictImpl(features: Vector): LeafNode = this override private[tree] def numDescendants: Int = 0 @@ -115,9 +126,8 @@ final class LeafNode private[ml] ( override private[tree] def subtreeDepth: Int = 0 override private[ml] def toOld(id: Int): OldNode = { - // NOTE: We do NOT store 'prob' in the new API currently. 
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, - None, None, None, None) + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), + impurity, isLeaf = true, None, None, None, None) } } @@ -139,17 +149,18 @@ final class InternalNode private[ml] ( val gain: Double, val leftChild: Node, val rightChild: Node, - val split: Split) extends Node { + val split: Split, + override val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" } - override private[ml] def predict(features: Vector): Double = { + override private[ml] def predictImpl(features: Vector): LeafNode = { if (split.shouldGoLeft(features)) { - leftChild.predict(features) + leftChild.predictImpl(features) } else { - rightChild.predict(features) + rightChild.predictImpl(features) } } @@ -172,9 +183,8 @@ final class InternalNode private[ml] ( override private[ml] def toOld(id: Int): OldNode = { assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + " since the old API does not support deep trees.") - // NOTE: We do NOT store 'prob' in the new API currently. - new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, - Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), Some(rightChild.toOld(OldNode.rightChildIndex(id))), Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, new OldPredict(leftChild.prediction, prob = 0.0), @@ -223,36 +233,36 @@ private object InternalNode { * * @param id We currently use the same indexing as the old implementation in * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. - * @param predictionStats Predicted label + class probability (for classification). - * We will later modify this to store aggregate statistics for labels - * to provide all class probabilities (for classification) and maybe a - * distribution (for regression). * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, * so that we do not need to consider splitting it further. - * @param stats Old structure for storing stats about information gain, prediction, etc. - * This is legacy and will be modified in the future. + * @param stats Impurity statistics for this node. */ private[tree] class LearningNode( var id: Int, - var predictionStats: OldPredict, - var impurity: Double, var leftChild: Option[LearningNode], var rightChild: Option[LearningNode], var split: Option[Split], var isLeaf: Boolean, - var stats: Option[OldInformationGainStats]) extends Serializable { + var stats: ImpurityStats) extends Serializable { /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ def toNode: Node = { if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty, + assert(rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. 
Could not convert LearningNode to Node.") - new InternalNode(predictionStats.predict, impurity, stats.get.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get) + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - new LeafNode(predictionStats.predict, impurity) + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + } + } } @@ -263,16 +273,14 @@ private[tree] object LearningNode { /** Create a node with some of its fields set. */ def apply( id: Int, - predictionStats: OldPredict, - impurity: Double, - isLeaf: Boolean): LearningNode = { - new LearningNode(id, predictionStats, impurity, None, None, None, false, None) + isLeaf: Boolean, + stats: ImpurityStats): LearningNode = { + new LearningNode(id, None, None, None, false, stats) } /** Create an empty node with the given node index. Values must be set later on. */ def emptyNode(nodeIndex: Int): LearningNode = { - new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN, - None, None, None, false, None) + new LearningNode(nodeIndex, None, None, None, false, null) } // The below indexing methods were copied from spark.mllib.tree.model.Node diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 15b56bd844bad..a8b90d9d266a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict} +import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging { parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) } case None => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) } @@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging { } // find best split for each node - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats, predict)) + (nodeIndex, (split, stats)) }.collectAsMap() 
timer.stop("chooseSplits") @@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging { val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.predictionStats = predict node.isLeaf = isLeaf - node.stats = Some(stats) - node.impurity = stats.impurity + node.stats = stats logDebug("Node = " + node) if (!isLeaf) { @@ -587,9 +589,9 @@ val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) if (nodeIdCache.nonEmpty) { val nodeIndexUpdater = NodeIndexUpdater( @@ -621,28 +623,44 @@ } /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates. + * @param stats the recycled impurity statistics for all of this feature's splits; + * only 'impurity' and 'impurityCalculator' are valid between iterations * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregates for this (feature, split) - * @return information gain and statistics for split + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) */ - private def calculateGainForSplit( + private def calculateImpurityStats( + stats: ImpurityStats, leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, - impurity: Double): InformationGainStats = { + metadata: DecisionTreeMetadata): ImpurityStats = { + + val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { + leftImpurityCalculator.copy.add(rightImpurityCalculator) + } else { + stats.impurityCalculator + } + + val impurity: Double = if (stats == null) { + parentImpurityCalculator.calculate() + } else { + stats.impurity + } + val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count + val totalCount = leftCount + rightCount + // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) || (rightCount < metadata.minInstancesPerNode)) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - val totalCount = leftCount + rightCount - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging { // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. if (gain < metadata.minInfoGain) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - // calculate left and right predict - val leftPredict = calculatePredict(leftImpurityCalculator) - val rightPredict = calculatePredict(rightImpurityCalculator) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, - leftPredict, rightPredict) - } - - private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { - val predict = impurityCalculator.predict - val prob = impurityCalculator.prob(predict) - new Predict(predict, prob) - } - - /** - * Calculate predict value for current node, given stats of any split. - * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node - */ - private def calculatePredictImpurity( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - val predict = calculatePredict(parentNodeAgg) - val impurity = parentNodeAgg.calculate() - - (predict, impurity) + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) } /** @@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging { binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, InformationGainStats, Predict) = { + node: LearningNode): (Split, ImpurityStats) = { - // Calculate prediction and impurity if current node is top node + // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) { - None + var gainAndImpurityStats: ImpurityStats = if (level == 0) { + null } else { - Some((node.predictionStats, node.impurity)) + node.stats } // For each (feature, split), calculate the gain, and select the best (feature, split).
@@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIdx, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { @@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { @@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) @@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging { } }.maxBy(_._2.gain) - (bestSplit, bestSplitStats, predictionAndImpurity.get._1) + (bestSplit, bestSplitStats) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 5ac10f3fd32dd..0768204c33914 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 19d318203c344..d0077db6832e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 578749d85a4e6..86cee7e430b0a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 7104a7fa4dd4c..04d0cd24e6632 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -98,7 +98,7 @@ private[tree] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). 
*/ -private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { require(stats.size == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index dc9e0f9f51ffb..508bf9c1bdb47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** * :: DeveloperApi :: @@ -66,7 +67,6 @@ class InformationGainStats( } } - private[spark] object InformationGainStats { /** * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to @@ -76,3 +76,62 @@ private[spark] object InformationGainStats { val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } + +/** + * :: DeveloperApi :: + * Impurity statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param impurityCalculator impurity statistics for current node + * @param leftImpurityCalculator impurity statistics for left child node + * @param rightImpurityCalculator impurity statistics for right child node + * @param valid whether the current split satisfies minimum info gain or + * minimum number of instances per node + */ +@DeveloperApi +private[spark] class ImpurityStats( + val gain: Double, + val impurity: Double, + val impurityCalculator: ImpurityCalculator, + val leftImpurityCalculator: ImpurityCalculator, + val rightImpurityCalculator: ImpurityCalculator, + val valid: Boolean = true) extends Serializable { + + override def toString: String = { + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" + } + + def leftImpurity: Double = if (leftImpurityCalculator != null) { + leftImpurityCalculator.calculate() + } else { + -1.0 + } + + def rightImpurity: Double = if (rightImpurityCalculator != null) { + rightImpurityCalculator.calculate() + } else { + -1.0 + } +} + +private[spark] object ImpurityStats { + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to + * denote that the current split doesn't satisfy the minimum info gain or + * minimum number of instances per node. + */ + def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + impurityCalculator, null, null, false) + } + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object + * in which only 'impurity' and 'impurityCalculator' are defined.
+ */ + def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 73b4805c4c597..c7bbf1ce07a23 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) ParamsSuite.checkParams(model) } @@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + val predictions = newTree.transform(newData) + .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index a7bc77965fefd..d4b5896c12c06 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new 
DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), Array(1.0)) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ab711c8e4b215..dbb2577c6204d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) ParamsSuite.checkParams(model) } From 0a1d2ca42c8b31d6b0e70163795f0185d4622f87 Mon Sep 17 00:00:00 2001 From: Iulian Dragos Date: Fri, 31 Jul 2015 12:04:03 -0700 Subject: [PATCH 205/219] [SPARK-8979] Add a PID based rate estimator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Based on #7600 /cc tdas Author: Iulian Dragos Author: François Garillot Closes #7648 from dragos/topic/streaming-bp/pid and squashes the following commits: aa5b097 [Iulian Dragos] Add more comments, made all PID constant parameters positive, a couple more tests. 93b74f8 [Iulian Dragos] Better explanation of historicalError. 7975b0c [Iulian Dragos] Add configuration for PID. 26cfd78 [Iulian Dragos] A couple of variable renames. d0bdf7c [Iulian Dragos] Update to latest version of the code, various style and name improvements. 
d58b845 [François Garillot] [SPARK-8979][Streaming] Implements a PIDRateEstimator --- .../dstream/ReceiverInputDStream.scala | 2 +- .../scheduler/rate/PIDRateEstimator.scala | 124 ++++++++++++++++ .../scheduler/rate/RateEstimator.scala | 18 ++- .../rate/PIDRateEstimatorSuite.scala | 137 ++++++++++++++++++ 4 files changed, 276 insertions(+), 5 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 646a8c3530a62..670ef8d296a0b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -46,7 +46,7 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont */ override protected[streaming] val rateController: Option[RateController] = { if (RateController.isBackPressureEnabled(ssc.conf)) { - RateEstimator.create(ssc.conf).map { new ReceiverRateController(id, _) } + Some(new ReceiverRateController(id, RateEstimator.create(ssc.conf, ssc.graph.batchDuration))) } else { None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala new file mode 100644 index 0000000000000..6ae56a68ad88c --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimator.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +/** + * Implements a proportional-integral-derivative (PID) controller which acts on + * the speed of ingestion of elements into Spark Streaming. A PID controller works + * by calculating an '''error''' between a measured output and a desired value. In the + * case of Spark Streaming the error is the difference between the measured processing + * rate (number of elements/processing delay) and the previous rate. + * + * @see https://en.wikipedia.org/wiki/PID_controller + * + * @param batchIntervalMillis the batch interval, in milliseconds + * @param proportional how much the correction should depend on the current + * error. This term usually provides the bulk of correction and should be positive or zero. + * A value too large would make the controller overshoot the setpoint, while a small value + * would make the controller too insensitive. The default value is 1.
+ * @param integral how much the correction should depend on the accumulation + * of past errors. This value should be positive or 0. This term accelerates the movement + * towards the desired value, but a large value may lead to overshooting. The default value + * is 0.2. + * @param derivative how much the correction should depend on a prediction + * of future errors, based on current rate of change. This value should be positive or 0. + * This term is not used very often, as it impacts stability of the system. The default + * value is 0. + */ +private[streaming] class PIDRateEstimator( + batchIntervalMillis: Long, + proportional: Double = 1D, + integral: Double = .2D, + derivative: Double = 0D) + extends RateEstimator { + + private var firstRun: Boolean = true + private var latestTime: Long = -1L + private var latestRate: Double = -1D + private var latestError: Double = -1L + + require( + batchIntervalMillis > 0, + s"Specified batch interval $batchIntervalMillis in PIDRateEstimator is invalid.") + require( + proportional >= 0, + s"Proportional term $proportional in PIDRateEstimator should be >= 0.") + require( + integral >= 0, + s"Integral term $integral in PIDRateEstimator should be >= 0.") + require( + derivative >= 0, + s"Derivative term $derivative in PIDRateEstimator should be >= 0.") + + + def compute(time: Long, // in milliseconds + numElements: Long, + processingDelay: Long, // in milliseconds + schedulingDelay: Long // in milliseconds + ): Option[Double] = { + + this.synchronized { + if (time > latestTime && processingDelay > 0 && batchIntervalMillis > 0) { + + // in seconds, should be close to batchDuration + val delaySinceUpdate = (time - latestTime).toDouble / 1000 + + // in elements/second + val processingRate = numElements.toDouble / processingDelay * 1000 + + // In our system `error` is the difference between the desired rate and the measured rate + // based on the latest batch information. We consider the desired rate to be latest rate, + // which is what this estimator calculated for the previous batch. + // in elements/second + val error = latestRate - processingRate + + // The error integral, based on schedulingDelay as an indicator for accumulated errors. + // A scheduling delay s corresponds to s * processingRate overflowing elements. Those + // are elements that couldn't be processed in previous batches, leading to this delay. + // In the following, we assume the processingRate didn't change too much. + // From the number of overflowing elements we can calculate the rate at which they would be + // processed by dividing it by the batch interval. This rate is our "historical" error, + // or integral part, since if we subtracted this rate from the previous "calculated rate", + // there wouldn't have been any overflowing elements, and the scheduling delay would have + // been zero. 
+ // (in elements/second) + val historicalError = schedulingDelay.toDouble * processingRate / batchIntervalMillis + + // in elements/(second ^ 2) + val dError = (error - latestError) / delaySinceUpdate + + val newRate = (latestRate - proportional * error - + integral * historicalError - + derivative * dError).max(0.0) + latestTime = time + if (firstRun) { + latestRate = processingRate + latestError = 0D + firstRun = false + + None + } else { + latestRate = newRate + latestError = error + + Some(newRate) + } + } else None + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index a08685119e5d5..17ccebc1ed41b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -19,6 +19,7 @@ package org.apache.spark.streaming.scheduler.rate import org.apache.spark.SparkConf import org.apache.spark.SparkException +import org.apache.spark.streaming.Duration /** * A component that estimates the rate at which an InputDStream should ingest @@ -48,12 +49,21 @@ object RateEstimator { /** * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. * - * @return None if there is no configured estimator, otherwise an instance of RateEstimator + * The only known estimator right now is `pid`. + * + * @return An instance of RateEstimator + * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any + * known estimators. */ - def create(conf: SparkConf): Option[RateEstimator] = - conf.getOption("spark.streaming.backpressure.rateEstimator").map { estimator => - throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + def create(conf: SparkConf, batchInterval: Duration): RateEstimator = + conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { + case "pid" => + val proportional = conf.getDouble("spark.streaming.backpressure.pid.proportional", 1.0) + val integral = conf.getDouble("spark.streaming.backpressure.pid.integral", 0.2) + val derived = conf.getDouble("spark.streaming.backpressure.pid.derived", 0.0) + new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived) + + case estimator => + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala new file mode 100644 index 0000000000000..97c32d8f2d59e --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/rate/PIDRateEstimatorSuite.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.scheduler.rate + +import scala.util.Random + +import org.scalatest.Inspectors.forAll +import org.scalatest.Matchers + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.Seconds + +class PIDRateEstimatorSuite extends SparkFunSuite with Matchers { + + test("the right estimator is created") { + val conf = new SparkConf + conf.set("spark.streaming.backpressure.rateEstimator", "pid") + val pid = RateEstimator.create(conf, Seconds(1)) + pid.getClass should equal(classOf[PIDRateEstimator]) + } + + test("estimator checks ranges") { + intercept[IllegalArgumentException] { + new PIDRateEstimator(0, 1, 2, 3) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, -1, 2, 3) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, -1, 3) + } + intercept[IllegalArgumentException] { + new PIDRateEstimator(100, 0, 0, -1) + } + } + + private def createDefaultEstimator: PIDRateEstimator = { + new PIDRateEstimator(20, 1D, 0D, 0D) + } + + test("first bound is None") { + val p = createDefaultEstimator + p.compute(0, 10, 10, 0) should equal(None) + } + + test("second bound is rate") { + val p = createDefaultEstimator + p.compute(0, 10, 10, 0) + // 1000 elements / s + p.compute(10, 10, 10, 0) should equal(Some(1000)) + } + + test("works even with no time between updates") { + val p = createDefaultEstimator + p.compute(0, 10, 10, 0) + p.compute(10, 10, 10, 0) + p.compute(10, 10, 10, 0) should equal(None) + } + + test("bound is never negative") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D) + // prepare a series of batch updates, one every 20ms, 0 processed elements, 2ms of processing + // this might point the estimator to try and decrease the bound, but we test it never + // goes below zero, which would be nonsensical. + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.fill(50)(0) // no processing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(100) // strictly positive accumulation + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.fill(49)(Some(0D))) + } + + test("with no accumulated or positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D) + // prepare a series of batch updates, one every 20ms with an increasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. 
Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => x * 20) // increasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some(x * 1000D)).tail) + } + + test("with no accumulated but some positive error, |I| > 0, follow the processing speed") { + val p = new PIDRateEstimator(20, 1D, 1D, 0D) + // prepare a series of batch updates, one every 20ms with a decreasing number of processed + // elements in each batch, but constant processing time, and no accumulated error. Even though + // the integral part is non-zero, the estimated rate should follow only the proportional term, + // asking for fewer and fewer elements + val times = List.tabulate(50)(x => x * 20) // every 20ms + val elements = List.tabulate(50)(x => (50 - x) * 20) // decreasing + val proc = List.fill(50)(20) // 20ms of processing + val sched = List.fill(50)(0) + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + res.tail should equal(List.tabulate(50)(x => Some((50 - x) * 1000D)).tail) + } + + test("with some accumulated and some positive error, |I| > 0, stay below the processing speed") { + val p = new PIDRateEstimator(20, 1D, .01D, 0D) + val times = List.tabulate(50)(x => x * 20) // every 20ms + val rng = new Random() + val elements = List.tabulate(50)(x => rng.nextInt(1000)) + val procDelayMs = 20 + val proc = List.fill(50)(procDelayMs) // 20ms of processing + val sched = List.tabulate(50)(x => rng.nextInt(19)) // random wait + val speeds = elements map ((x) => x.toDouble / procDelayMs * 1000) + + val res = for (i <- List.range(0, 50)) yield p.compute(times(i), elements(i), proc(i), sched(i)) + res.head should equal(None) + forAll(List.range(1, 50)) { (n) => + res(n) should not be None + if (res(n).get > 0 && sched(n) > 0) { + res(n).get should be < speeds(n) + } + } + } +} From 39ab199a3f735b7658ab3331d3e2fb03441aec13 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 31 Jul 2015 12:07:18 -0700 Subject: [PATCH 206/219] [SPARK-8640] [SQL] Enable Processing of Multiple Window Frames in a Single Window Operator This PR enables the processing of multiple window frames in a single window operator. This should improve the performance of processing multiple window expressions which share partition by/order by clauses, because it will be more efficient with respect to memory use and group processing. A usage sketch follows the commit list below. Author: Herman van Hovell Closes #7515 from hvanhovell/SPARK-8640 and squashes the following commits: f0e1c21 [Herman van Hovell] Changed Window Logical/Physical plans to use partition by/order by specs directly instead of using WindowSpec. e1711c2 [Herman van Hovell] Enabled the processing of multiple window frames in a single Window operator.
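For example (a minimal sketch assuming a DataFrame `df` with columns `grp` and `val`, plus the usual `sqlContext.implicits._` import for the $-syntax), two different frames over the same partition by/order by clause are now evaluated by a single Window operator:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Shared partitioning and ordering; only the frames differ.
    val w = Window.partitionBy($"grp").orderBy($"val")
    val query = df.select(
      sum($"val").over(w.rowsBetween(-1, 1)),   // row-based frame
      sum($"val").over(w.rangeBetween(-1, 1)))  // range-based frame
    // Previously each frame produced its own Window operator; with this
    // change both share one operator and one sort/partition pass.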
--- .../sql/catalyst/analysis/Analyzer.scala | 12 +++++++----- .../plans/logical/basicOperators.scala | 3 ++- .../spark/sql/execution/SparkStrategies.scala | 5 +++-- .../apache/spark/sql/execution/Window.scala | 19 ++++++++++--------- .../sql/hive/execution/HivePlanTest.scala | 18 ++++++++++++++++++ 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 265f3d1e41765..51d910b258647 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -347,7 +347,7 @@ class Analyzer( val newOutput = oldVersion.generatorOutput.map(_.newInstance()) (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - case oldVersion @ Window(_, windowExpressions, _, child) + case oldVersion @ Window(_, windowExpressions, _, _, child) if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) @@ -825,7 +825,7 @@ class Analyzer( }.asInstanceOf[NamedExpression] } - // Second, we group extractedWindowExprBuffer based on their Window Spec. + // Second, we group extractedWindowExprBuffer based on their Partition and Order Specs. val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr => val distinctWindowSpec = expr.collect { case window: WindowExpression => window.windowSpec @@ -841,7 +841,8 @@ class Analyzer( failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." + s"Please file a bug report with this error message, stack trace, and the query.") } else { - distinctWindowSpec.head + val spec = distinctWindowSpec.head + (spec.partitionSpec, spec.orderSpec) } }.toSeq @@ -850,9 +851,10 @@ class Analyzer( var currentChild = child var i = 0 while (i < groupedWindowExpressions.size) { - val (windowSpec, windowExpressions) = groupedWindowExpressions(i) + val ((partitionSpec, orderSpec), windowExpressions) = groupedWindowExpressions(i) // Set currentChild to the newly created Window operator. - currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild) + currentChild = Window(currentChild.output, windowExpressions, + partitionSpec, orderSpec, currentChild) // Move to next Window Spec. 
i += 1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a67f8de6b733a..aacfc86ab0e49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -228,7 +228,8 @@ case class Aggregate( case class Window( projectList: Seq[Attribute], windowExpressions: Seq[NamedExpression], - windowSpec: WindowSpecDefinition, + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 03d24a88d4ecd..4aff52d992e6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -389,8 +389,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil } } - case logical.Window(projectList, windowExpressions, spec, child) => - execution.Window(projectList, windowExpressions, spec, planLater(child)) :: Nil + case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => + execution.Window( + projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil case logical.Sample(lb, ub, withReplacement, seed, child) => execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 91c8a02e2b5bc..fe9f2c7028171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -80,23 +80,24 @@ import scala.collection.mutable case class Window( projectList: Seq[Attribute], windowExpression: Seq[NamedExpression], - windowSpec: WindowSpecDefinition, + partitionSpec: Seq[Expression], + orderSpec: Seq[SortOrder], child: SparkPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList ++ windowExpression.map(_.toAttribute) override def requiredChildDistribution: Seq[Distribution] = { - if (windowSpec.partitionSpec.isEmpty) { + if (partitionSpec.isEmpty) { // Only show warning when the number of bytes is larger than 100 MB? logWarning("No Partition Defined for Window operation! Moving all data to a single " + "partition, this can cause serious performance degradation.") AllTuples :: Nil - } else ClusteredDistribution(windowSpec.partitionSpec) :: Nil + } else ClusteredDistribution(partitionSpec) :: Nil } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(windowSpec.partitionSpec.map(SortOrder(_, Ascending)) ++ windowSpec.orderSpec) + Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -115,12 +116,12 @@ case class Window( case RangeFrame => val (exprs, current, bound) = if (offset == 0) { // Use the entire order expression when the offset is 0. 
- val exprs = windowSpec.orderSpec.map(_.child) + val exprs = orderSpec.map(_.child) val projection = newMutableProjection(exprs, child.output) - (windowSpec.orderSpec, projection(), projection()) - } else if (windowSpec.orderSpec.size == 1) { + (orderSpec, projection(), projection()) + } else if (orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. - val sortExpr = windowSpec.orderSpec.head + val sortExpr = orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() @@ -250,7 +251,7 @@ case class Window( // Get all relevant projections. val result = createResultProjection(unboundExpressions) - val grouping = newProjection(windowSpec.partitionSpec, child.output) + val grouping = newProjection(partitionSpec, child.output) // Manage the stream and the grouping. var nextRow: InternalRow = EmptyRow diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala index bdb53ddf59c19..ba56a8a6b689c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.functions._ import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.TestHive class HivePlanTest extends QueryTest { @@ -31,4 +34,19 @@ class HivePlanTest extends QueryTest { comparePlans(optimized, correctAnswer) } + + test("window expressions sharing the same partition by and order by clause") { + val df = Seq.empty[(Int, String, Int, Int)].toDF("id", "grp", "seq", "val") + val window = Window. + partitionBy($"grp"). + orderBy($"val") + val query = df.select( + $"id", + sum($"val").over(window.rowsBetween(-1, 1)), + sum($"val").over(window.rangeBetween(-1, 1)) + ) + val plan = query.queryExecution.analyzed + assert(plan.collect{ case w: logical.Window => w }.size === 1, + "Should have only 1 Window operator.") + } } From 3afc1de89cb4de9f8ea74003dd1e6b5b006d06f0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 31 Jul 2015 12:09:48 -0700 Subject: [PATCH 207/219] [SPARK-8564] [STREAMING] Add the Python API for Kinesis This PR adds the Python API for Kinesis, including a Python example and a simple unit test. 
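For context, a minimal sketch of how the new binding is meant to be used from PySpark (the app name, stream name, and endpoint below are placeholders, not part of this patch):

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream

    sc = SparkContext(appName="KinesisSketch")
    ssc = StreamingContext(sc, 1)

    # Returns a DStream of UTF-8 decoded strings; AWS credentials fall back to
    # the DefaultAWSCredentialsProviderChain when not passed explicitly.
    lines = KinesisUtils.createStream(
        ssc, "myKinesisApp", "myStream",
        "https://kinesis.us-east-1.amazonaws.com", "us-east-1",
        InitialPositionInStream.TRIM_HORIZON, 2)

    lines.pprint()
    ssc.start()
    ssc.awaitTermination()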
Author: zsxwing Closes #6955 from zsxwing/kinesis-python and squashes the following commits: e42e471 [zsxwing] Merge branch 'master' into kinesis-python 455f7ea [zsxwing] Remove streaming_kinesis_asl_assembly module and simply add the source folder to streaming_kinesis_asl module 32e6451 [zsxwing] Merge remote-tracking branch 'origin/master' into kinesis-python 5082d28 [zsxwing] Fix the syntax error for Python 2.6 fca416b [zsxwing] Fix wrong comparison 96670ff [zsxwing] Fix the compilation error after merging master 756a128 [zsxwing] Merge branch 'master' into kinesis-python 6c37395 [zsxwing] Print stack trace for debug 7c5cfb0 [zsxwing] RUN_KINESIS_TESTS -> ENABLE_KINESIS_TESTS cc9d071 [zsxwing] Fix the python test errors 466b425 [zsxwing] Add python tests for Kinesis e33d505 [zsxwing] Merge remote-tracking branch 'origin/master' into kinesis-python 3da2601 [zsxwing] Fix the kinesis folder 687446b [zsxwing] Fix the error message and the maven output path add2beb [zsxwing] Merge branch 'master' into kinesis-python 4957c0b [zsxwing] Add the Python API for Kinesis --- dev/run-tests.py | 3 +- dev/sparktestsupport/modules.py | 9 +- docs/streaming-kinesis-integration.md | 19 +++ extras/kinesis-asl-assembly/pom.xml | 103 ++++++++++++++++ .../streaming/kinesis_wordcount_asl.py | 81 +++++++++++++ .../streaming/kinesis/KinesisTestUtils.scala | 19 ++- .../streaming/kinesis/KinesisUtils.scala | 78 +++++++++--- pom.xml | 1 + project/SparkBuild.scala | 6 +- python/pyspark/streaming/kinesis.py | 112 ++++++++++++++++++ python/pyspark/streaming/tests.py | 86 +++++++++++++- 11 files changed, 492 insertions(+), 25 deletions(-) create mode 100644 extras/kinesis-asl-assembly/pom.xml create mode 100644 extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py create mode 100644 python/pyspark/streaming/kinesis.py diff --git a/dev/run-tests.py b/dev/run-tests.py index 29420da9aa956..b6d181418f027 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -301,7 +301,8 @@ def build_spark_sbt(hadoop_version): sbt_goals = ["package", "assembly/assembly", "streaming-kafka-assembly/assembly", - "streaming-flume-assembly/assembly"] + "streaming-flume-assembly/assembly", + "streaming-kinesis-asl-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 44600cb9523c1..956dc81b62e93 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -138,6 +138,7 @@ def contains_file(self, filename): dependencies=[], source_file_regexes=[ "extras/kinesis-asl/", + "extras/kinesis-asl-assembly/", ], build_profile_flags=[ "-Pkinesis-asl", @@ -300,7 +301,13 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], + dependencies=[ + pyspark_core, + streaming, + streaming_kafka, + streaming_flume_assembly, + streaming_kinesis_asl + ], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index aa9749afbc867..a7bcaec6fcd84 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -51,6 +51,17 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) and the 
[example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + +
+ from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream + + kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) + + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. +
@@ -135,6 +146,14 @@ To run the example, bin/run-example streaming.JavaKinesisWordCountASL [Kinesis app name] [Kinesis stream name] [endpoint URL] + +
+ + bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\ + spark-streaming-kinesis-asl-assembly_*.jar \ + extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \ + [Kinesis app name] [Kinesis stream name] [endpoint URL] [region name] +
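A note on credentials (a sketch under the same placeholder names as above, not part of the patch): the Python binding accepts AWS keys either implicitly or explicitly, and the JVM-side helper rejects a half-specified pair, so pass both keys or neither:

    # Implicit: resolved via the DefaultAWSCredentialsProviderChain
    # (environment variables, system properties, ~/.aws/credentials, EC2 profile).
    stream1 = KinesisUtils.createStream(
        ssc, "myKinesisApp", "myStream", "https://kinesis.us-west-2.amazonaws.com",
        "us-west-2", InitialPositionInStream.LATEST, 2)

    # Explicit: awsAccessKeyId and awsSecretKey must be supplied together;
    # giving only one raises IllegalArgumentException on the JVM side.
    stream2 = KinesisUtils.createStream(
        ssc, "myKinesisApp", "myStream", "https://kinesis.us-west-2.amazonaws.com",
        "us-west-2", InitialPositionInStream.LATEST, 2,
        awsAccessKeyId="<access-key>", awsSecretKey="<secret-key>")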
diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml
new file mode 100644
index 0000000000000..70d2c9c58f54e
--- /dev/null
+++ b/extras/kinesis-asl-assembly/pom.xml
@@ -0,0 +1,103 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.10</artifactId>
+    <version>1.5.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>spark-streaming-kinesis-asl-assembly_2.10</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Kinesis Assembly</name>
+  <url>http://spark.apache.org/</url>
+
+  <properties>
+    <sbt.project.name>streaming-kinesis-asl-assembly</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming-kinesis-asl_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <outputFile>${project.build.directory}/scala-${scala.binary.version}/spark-streaming-kinesis-asl-assembly-${project.version}.jar</outputFile>
+          <artifactSet>
+            <includes>
+              <include>*:*</include>
+            </includes>
+          </artifactSet>
+          <filters>
+            <filter>
+              <artifact>*:*</artifact>
+              <excludes>
+                <exclude>META-INF/*.SF</exclude>
+                <exclude>META-INF/*.DSA</exclude>
+                <exclude>META-INF/*.RSA</exclude>
+              </excludes>
+            </filter>
+          </filters>
+        </configuration>
+        <executions>
+          <execution>
+            <phase>package</phase>
+            <goals>
+              <goal>shade</goal>
+            </goals>
+            <configuration>
+              <transformers>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
+                  <resource>reference.conf</resource>
+                </transformer>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.DontIncludeResourceTransformer">
+                  <resource>log4j.properties</resource>
+                </transformer>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheLicenseResourceTransformer"/>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheNoticeResourceTransformer"/>
+              </transformers>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
new file mode 100644
index 0000000000000..f428f64da3c42
--- /dev/null
+++ b/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py
@@ -0,0 +1,81 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+  Consumes messages from an Amazon Kinesis stream and does wordcount.
+
+  This example spins up 1 Kinesis Receiver per shard for the given stream.
+  It then starts pulling from the last checkpointed sequence number of the given stream.
+
+  Usage: kinesis_wordcount_asl.py <app-name> <stream-name> <endpoint-url> <region-name>
+    <app-name> is the name of the consumer app, used to track the read data in DynamoDB
+    <stream-name> name of the Kinesis stream (ie. mySparkStream)
+    <endpoint-url> endpoint of the Kinesis service
+      (e.g. https://kinesis.us-east-1.amazonaws.com)
+    <region-name> region name of the Kinesis endpoint (e.g. us-east-1)
+
+  Example:
+      # export AWS keys if necessary
+      $ export AWS_ACCESS_KEY_ID=<your-access-key-id>
+      $ export AWS_SECRET_KEY=<your-secret-key>
+
+      # run the example
+      $ bin/spark-submit --jars extras/kinesis-asl/target/scala-*/\
+        spark-streaming-kinesis-asl-assembly_*.jar \
+        extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py \
+        myAppName mySparkStream https://kinesis.us-east-1.amazonaws.com us-east-1
+
+  There is a companion helper class called KinesisWordProducerASL which puts dummy data
+  onto the Kinesis stream.
+
+  This code uses the DefaultAWSCredentialsProviderChain to find credentials
+  in the following order:
+      Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
+      Java System Properties - aws.accessKeyId and aws.secretKey
+      Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
+      Instance profile credentials - delivered through the Amazon EC2 metadata service
+  For more information, see
+      http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html
+
+  See http://spark.apache.org/docs/latest/streaming-kinesis-integration.html for more details on
+  the Kinesis Spark Streaming integration.
+"""
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream
+
+if __name__ == "__main__":
+    if len(sys.argv) != 5:
+        print(
+            "Usage: kinesis_wordcount_asl.py <app-name> <stream-name> <endpoint-url> <region-name>",
+            file=sys.stderr)
+        sys.exit(-1)
+
+    sc = SparkContext(appName="PythonStreamingKinesisWordCountAsl")
+    ssc = StreamingContext(sc, 1)
+    appName, streamName, endpointUrl, regionName = sys.argv[1:]
+    lines = KinesisUtils.createStream(
+        ssc, appName, streamName, endpointUrl, regionName, InitialPositionInStream.LATEST, 2)
+    counts = lines.flatMap(lambda line: line.split(" ")) \
+        .map(lambda word: (word, 1)) \
+        .reduceByKey(lambda a, b: a+b)
+    counts.pprint()
+
+    ssc.start()
+    ssc.awaitTermination()
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index ca39358b75cb6..255ac27f793ba 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -36,9 +36,15 @@ import org.apache.spark.Logging
 /**
  * Shared utility methods for performing Kinesis tests that actually transfer data
  */
-private class KinesisTestUtils(
-    val endpointUrl: String = "https://kinesis.us-west-2.amazonaws.com",
-    _regionName: String = "") extends Logging {
+private class KinesisTestUtils(val endpointUrl: String, _regionName: String) extends Logging {
+
+  def this() {
+    this("https://kinesis.us-west-2.amazonaws.com", "")
+  }
+
+  def this(endpointUrl: String) {
+    this(endpointUrl, "")
+  }
 
   val regionName = if (_regionName.length == 0) {
     RegionUtils.getRegionByEndpoint(endpointUrl).getName()
@@ -117,6 +123,13 @@ private class KinesisTestUtils(
     shardIdToSeqNumbers.toMap
   }
 
+  /**
+   * Expose a Python-friendly API.
+ */ + def pushData(testData: java.util.List[Int]): Unit = { + pushData(scala.collection.JavaConversions.asScalaBuffer(testData)) + } + def deleteStream(): Unit = { try { if (streamCreated) { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index e5acab50181e1..7dab17eba8483 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -86,19 +86,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( ssc: StreamingContext, @@ -130,7 +130,7 @@ object KinesisUtils { * - The Kinesis application name used by the Kinesis Client Library (KCL) will be the app name in * [[org.apache.spark.SparkConf]]. * - * @param ssc Java StreamingContext object + * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Endpoint url of Kinesis service * (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -175,15 +175,15 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. 
+ * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. */ @@ -206,8 +206,8 @@ object KinesisUtils { * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. + * The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library @@ -216,19 +216,19 @@ object KinesisUtils { * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param regionName Name of region used by the Kinesis Client Library (KCL) to update * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) - * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) - * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) - * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. - * See the Kinesis Spark Streaming documentation for more - * details on the different types of checkpoints. * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the * worker's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * (InitialPositionInStream.TRIM_HORIZON) or * the tip of the stream (InitialPositionInStream.LATEST). + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) + * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) */ def createStream( jssc: JavaStreamingContext, @@ -297,3 +297,49 @@ object KinesisUtils { } } } + +/** + * This is a helper class that wraps the methods in KinesisUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's KinesisUtils. + */ +private class KinesisUtilsPythonHelper { + + def getInitialPositionInStream(initialPositionInStream: Int): InitialPositionInStream = { + initialPositionInStream match { + case 0 => InitialPositionInStream.LATEST + case 1 => InitialPositionInStream.TRIM_HORIZON + case _ => throw new IllegalArgumentException( + "Illegal InitialPositionInStream. 
Please use " + + "InitialPositionInStream.LATEST or InitialPositionInStream.TRIM_HORIZON") + } + } + + def createStream( + jssc: JavaStreamingContext, + kinesisAppName: String, + streamName: String, + endpointUrl: String, + regionName: String, + initialPositionInStream: Int, + checkpointInterval: Duration, + storageLevel: StorageLevel, + awsAccessKeyId: String, + awsSecretKey: String + ): JavaReceiverInputDStream[Array[Byte]] = { + if (awsAccessKeyId == null && awsSecretKey != null) { + throw new IllegalArgumentException("awsSecretKey is set but awsAccessKeyId is null") + } + if (awsAccessKeyId != null && awsSecretKey == null) { + throw new IllegalArgumentException("awsAccessKeyId is set but awsSecretKey is null") + } + if (awsAccessKeyId == null && awsSecretKey == null) { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel) + } else { + KinesisUtils.createStream(jssc, kinesisAppName, streamName, endpointUrl, regionName, + getInitialPositionInStream(initialPositionInStream), checkpointInterval, storageLevel, + awsAccessKeyId, awsSecretKey) + } + } + +} diff --git a/pom.xml b/pom.xml index 35fc8c44bc1b0..e351c7c19df96 100644 --- a/pom.xml +++ b/pom.xml @@ -1642,6 +1642,7 @@ kinesis-asl extras/kinesis-asl + extras/kinesis-asl-assembly diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 61a05d375d99e..9a33baa7c6ce1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -382,7 +382,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly") || mName.contains("streaming-kinesis-asl-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { diff --git a/python/pyspark/streaming/kinesis.py b/python/pyspark/streaming/kinesis.py new file mode 100644 index 0000000000000..bcfe2703fecf9 --- /dev/null +++ b/python/pyspark/streaming/kinesis.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_gateway import Py4JJavaError + +from pyspark.serializers import PairDeserializer, NoOpSerializer +from pyspark.storagelevel import StorageLevel +from pyspark.streaming import DStream + +__all__ = ['KinesisUtils', 'InitialPositionInStream', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class KinesisUtils(object): + + @staticmethod + def createStream(ssc, kinesisAppName, streamName, endpointUrl, regionName, + initialPositionInStream, checkpointInterval, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, + awsAccessKeyId=None, awsSecretKey=None, decoder=utf8_decoder): + """ + Create an input stream that pulls messages from a Kinesis stream. This uses the + Kinesis Client Library (KCL) to pull messages from Kinesis. + + Note: The given AWS credentials will get saved in DStream checkpoints if checkpointing is + enabled. Make sure that your checkpoint directory is secure. + + :param ssc: StreamingContext object + :param kinesisAppName: Kinesis application name used by the Kinesis Client Library (KCL) to + update DynamoDB + :param streamName: Kinesis stream name + :param endpointUrl: Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + :param regionName: Name of region used by the Kinesis Client Library (KCL) to update + DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) + :param initialPositionInStream: In the absence of Kinesis checkpoint info, this is the + worker's initial starting position in the stream. The + values are either the beginning of the stream per Kinesis' + limit of 24 hours (InitialPositionInStream.TRIM_HORIZON) or + the tip of the stream (InitialPositionInStream.LATEST). + :param checkpointInterval: Checkpoint interval for Kinesis checkpointing. See the Kinesis + Spark Streaming documentation for more details on the different + types of checkpoints. + :param storageLevel: Storage level to use for storing the received objects (default is + StorageLevel.MEMORY_AND_DISK_2) + :param awsAccessKeyId: AWS AccessKeyId (default is None. If None, will use + DefaultAWSCredentialsProviderChain) + :param awsSecretKey: AWS SecretKey (default is None. 
If None, will use + DefaultAWSCredentialsProviderChain) + :param decoder: A function used to decode value (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + jduration = ssc._jduration(checkpointInterval) + + try: + # Use KinesisUtilsPythonHelper to access Scala's KinesisUtils + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.kinesis.KinesisUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, kinesisAppName, streamName, endpointUrl, + regionName, initialPositionInStream, jduration, jlevel, + awsAccessKeyId, awsSecretKey) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + KinesisUtils._printErrorMsg(ssc.sparkContext) + raise e + stream = DStream(jstream, ssc, NoOpSerializer()) + return stream.map(lambda v: decoder(v)) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Kinesis libraries not found in class path. Try one of the following. + + 1. Include the Kinesis library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-kinesis-asl:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-kinesis-asl-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) + + +class InitialPositionInStream(object): + LATEST, TRIM_HORIZON = (0, 1) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 4ecae1e4bf282..5cd544b2144ef 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -36,9 +36,11 @@ import unittest from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.storagelevel import StorageLevel from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition from pyspark.streaming.flume import FlumeUtils +from pyspark.streaming.kinesis import KinesisUtils, InitialPositionInStream class PySparkStreamingTestCase(unittest.TestCase): @@ -891,6 +893,67 @@ def test_flume_polling_multiple_hosts(self): self._testMultipleTimes(self._testFlumePollingMultipleHosts) +class KinesisStreamTests(PySparkStreamingTestCase): + + def test_kinesis_stream_api(self): + # Don't start the StreamingContext because we cannot test it in Jenkins + kinesisStream1 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2) + kinesisStream2 = KinesisUtils.createStream( + self.ssc, "myAppNam", "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", "us-west-2", + InitialPositionInStream.LATEST, 2, StorageLevel.MEMORY_AND_DISK_2, + "awsAccessKey", "awsSecretKey") + + def test_kinesis_stream(self): + if os.environ.get('ENABLE_KINESIS_TESTS') != '1': + print("Skip test_kinesis_stream") + return + + import random + kinesisAppName = ("KinesisStreamTests-%d" % abs(random.randint(0, 10000000))) + kinesisTestUtilsClz = \ + 
self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \
+            .loadClass("org.apache.spark.streaming.kinesis.KinesisTestUtils")
+        kinesisTestUtils = kinesisTestUtilsClz.newInstance()
+        try:
+            kinesisTestUtils.createStream()
+            aWSCredentials = kinesisTestUtils.getAWSCredentials()
+            stream = KinesisUtils.createStream(
+                self.ssc, kinesisAppName, kinesisTestUtils.streamName(),
+                kinesisTestUtils.endpointUrl(), kinesisTestUtils.regionName(),
+                InitialPositionInStream.LATEST, 10, StorageLevel.MEMORY_ONLY,
+                aWSCredentials.getAWSAccessKeyId(), aWSCredentials.getAWSSecretKey())
+
+            outputBuffer = []
+
+            def get_output(_, rdd):
+                for e in rdd.collect():
+                    outputBuffer.append(e)
+
+            stream.foreachRDD(get_output)
+            self.ssc.start()
+
+            testData = [i for i in range(1, 11)]
+            expectedOutput = set([str(i) for i in testData])
+            start_time = time.time()
+            while time.time() - start_time < 120:
+                kinesisTestUtils.pushData(testData)
+                if expectedOutput == set(outputBuffer):
+                    break
+                time.sleep(10)
+            self.assertEqual(expectedOutput, set(outputBuffer))
+        except:
+            import traceback
+            traceback.print_exc()
+            raise
+        finally:
+            kinesisTestUtils.deleteStream()
+            kinesisTestUtils.deleteDynamoDBTable(kinesisAppName)
+
+
 def search_kafka_assembly_jar():
     SPARK_HOME = os.environ["SPARK_HOME"]
     kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly")
@@ -926,10 +989,31 @@ def search_flume_assembly_jar():
     else:
         return jars[0]
 
+
+def search_kinesis_asl_assembly_jar():
+    SPARK_HOME = os.environ["SPARK_HOME"]
+    kinesis_asl_assembly_dir = os.path.join(SPARK_HOME, "extras/kinesis-asl-assembly")
+    jars = glob.glob(
+        os.path.join(kinesis_asl_assembly_dir,
+                     "target/scala-*/spark-streaming-kinesis-asl-assembly-*.jar"))
+    if not jars:
+        raise Exception(
+            ("Failed to find Spark Streaming Kinesis ASL assembly jar in %s. " %
+             kinesis_asl_assembly_dir) + "You need to build Spark with "
+            "'build/sbt -Pkinesis-asl assembly/assembly streaming-kinesis-asl-assembly/assembly' "
+            "or 'build/mvn -Pkinesis-asl package' before running this test")
+    elif len(jars) > 1:
+        raise Exception(("Found multiple Spark Streaming Kinesis ASL assembly JARs in %s; please "
+                         "remove all but one") % kinesis_asl_assembly_dir)
+    else:
+        return jars[0]
+
+
 if __name__ == "__main__":
     kafka_assembly_jar = search_kafka_assembly_jar()
     flume_assembly_jar = search_flume_assembly_jar()
-    jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar)
+    kinesis_asl_assembly_jar = search_kinesis_asl_assembly_jar()
+    jars = "%s,%s,%s" % (kafka_assembly_jar, flume_assembly_jar, kinesis_asl_assembly_jar)
     os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars
     unittest.main()

From d04634701413410938a133358fe1d9fbc077645e Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Fri, 31 Jul 2015 12:10:55 -0700
Subject: [PATCH 208/219] [SPARK-9504] [STREAMING] [TESTS] Use eventually to fix the flaky test

The previous code uses `ssc.awaitTerminationOrTimeout(500)`. Since nothing stops the context during `awaitTerminationOrTimeout`, it is effectively just `sleep(500)`. On a heavily overloaded Jenkins worker, the receiver may not be able to start within 500 milliseconds. Verified this in the log of https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/39149/: there is no log line about the receiver starting before this failure, which is why `assert(runningCount > 0)` failed. This PR replaces `awaitTerminationOrTimeout` with `eventually`, which should be more reliable.
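For intuition, the pattern in a Python sketch (the helper below is hypothetical and not part of this patch; the actual fix in the diff uses ScalaTest's `eventually`):

    import time

    def eventually(condition, timeout=10.0, interval=0.01):
        # Poll `condition` until it holds or `timeout` elapses, instead of
        # sleeping a fixed 500 ms and hoping the receiver started in time.
        deadline = time.time() + timeout
        while not condition():
            if time.time() > deadline:
                raise AssertionError("condition not met within %ss" % timeout)
            time.sleep(interval)

    # e.g. replace a fixed sleep with: eventually(lambda: running_count > 0)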
Author: zsxwing Closes #7823 from zsxwing/SPARK-9504 and squashes the following commits: 7af66a6 [zsxwing] Remove wrong assertion 5ba2c99 [zsxwing] Use eventually to fix the flaky test --- .../apache/spark/streaming/StreamingContextSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 84a5fbb3d95eb..b7db280f63588 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -261,7 +261,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo for (i <- 1 to 4) { logInfo("==================================\n\n\n") ssc = new StreamingContext(sc, Milliseconds(100)) - var runningCount = 0 + @volatile var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) input.count().foreachRDD { rdd => @@ -270,14 +270,14 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo logInfo("Count = " + count + ", Running count = " + runningCount) } ssc.start() - ssc.awaitTerminationOrTimeout(500) + eventually(timeout(10.seconds), interval(10.millis)) { + assert(runningCount > 0) + } ssc.stop(stopSparkContext = false, stopGracefully = true) logInfo("Running count = " + runningCount) logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) - assert(runningCount > 0) assert( - (TestReceiver.counter.get() == runningCount + 1) || - (TestReceiver.counter.get() == runningCount + 2), + TestReceiver.counter.get() == runningCount + 1, "Received records = " + TestReceiver.counter.get() + ", " + "processed records = " + runningCount ) From a8340fa7df17e3f0a3658f8b8045ab840845a72a Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Fri, 31 Jul 2015 12:12:22 -0700 Subject: [PATCH 209/219] [SPARK-9481] Add logLikelihood to LocalLDAModel jkbradley Exposes `bound` (variational log likelihood bound) through public API as `logLikelihood`. Also adds unit tests, some DRYing of `LDASuite`, and includes unit tests mentioned in #7760 Author: Feynman Liang Closes #7801 from feynmanliang/SPARK-9481-logLikelihood and squashes the following commits: 6d1b2c9 [Feynman Liang] Negate perplexity definition 5f62b20 [Feynman Liang] Add logLikelihood --- .../spark/mllib/clustering/LDAModel.scala | 20 ++- .../spark/mllib/clustering/LDASuite.scala | 129 +++++++++--------- 2 files changed, 78 insertions(+), 71 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 82281a0daf008..ff7035d2246c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -217,22 +217,28 @@ class LocalLDAModel private[clustering] ( LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix, docConcentration, topicConcentration, gammaShape) } - // TODO - // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: declare in LDAModel and override once implemented in DistributedLDAModel + /** + * Calculates a lower bound on the log likelihood of the entire corpus. 
+   * @param documents test corpus to use for calculating log likelihood
+   * @return variational lower bound on the log likelihood of the entire corpus
+   */
+  def logLikelihood(documents: RDD[(Long, Vector)]): Double = bound(documents,
+    docConcentration, topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k,
+    vocabSize)
 
   /**
-   * Calculate the log variational bound on perplexity. See Equation (16) in original Online
+   * Calculate an upper bound on perplexity. See Equation (16) in the original Online
    * LDA paper.
    * @param documents test corpus to use for calculating perplexity
-   * @return the log perplexity per word
+   * @return variational upper bound on log perplexity per word
    */
   def logPerplexity(documents: RDD[(Long, Vector)]): Double = {
     val corpusWords = documents
       .map { case (_, termCounts) => termCounts.toArray.sum }
       .sum()
-    val batchVariationalBound = bound(documents, docConcentration,
-      topicConcentration, topicsMatrix.toBreeze.toDenseMatrix, gammaShape, k, vocabSize)
-    val perWordBound = batchVariationalBound / corpusWords
+    val perWordBound = -logLikelihood(documents) / corpusWords
 
     perWordBound
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 695ee3b82efc5..79d2a1cafd1fa 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -210,16 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
   }
 
   test("OnlineLDAOptimizer with toy data") {
-    def toydata: Array[(Long, Vector)] = Array(
-      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
-      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
-      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
-      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
-      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
-      Vectors.sparse(6, Array(4, 5), Array(1, 1))
-    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
-
-    val docs = sc.parallelize(toydata)
+    val docs = sc.parallelize(toyData)
     val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
       .setGammaShape(1e10)
     val lda = new LDA().setK(2)
@@ -242,30 +233,45 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
-  test("LocalLDAModel logPerplexity") {
-    val k = 2
-    val vocabSize = 6
-    val alpha = 0.01
-    val eta = 0.01
-    val gammaShape = 100
-    // obtained from LDA model trained in gensim, see below
-    val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array(
-      1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597,
-      0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124))
+  test("LocalLDAModel logLikelihood") {
+    val ldaModel: LocalLDAModel = toyModel
 
-    def toydata: Array[(Long, Vector)] = Array(
-      Vectors.sparse(6, Array(0, 1), Array(1, 1)),
-      Vectors.sparse(6, Array(1, 2), Array(1, 1)),
-      Vectors.sparse(6, Array(0, 2), Array(1, 1)),
-      Vectors.sparse(6, Array(3, 4), Array(1, 1)),
-      Vectors.sparse(6, Array(3, 5), Array(1, 1)),
-      Vectors.sparse(6, Array(4, 5), Array(1, 1))
-    ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
-    val docs = sc.parallelize(toydata)
+    val docsSingleWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(1)))
+      .zipWithIndex
+      .map { case (wordCounts, docId) => (docId.toLong, wordCounts) })
+    val docsRepeatedWord = sc.parallelize(Array(Vectors.sparse(6, Array(0), Array(5)))
+      .zipWithIndex
+      .map {
case (wordCounts, docId) => (docId.toLong, wordCounts) }) + /* Verify results using gensim: + import numpy as np + from gensim import models + corpus = [ + [(0, 1.0), (1, 1.0)], + [(1, 1.0), (2, 1.0)], + [(0, 1.0), (2, 1.0)], + [(3, 1.0), (4, 1.0)], + [(3, 1.0), (5, 1.0)], + [(4, 1.0), (5, 1.0)]] + np.random.seed(2345) + lda = models.ldamodel.LdaModel( + corpus=corpus, alpha=0.01, eta=0.01, num_topics=2, update_every=0, passes=100, + decay=0.51, offset=1024) + docsSingleWord = [[(0, 1.0)]] + docsRepeatedWord = [[(0, 5.0)]] + print(lda.bound(docsSingleWord)) + > -25.9706969833 + print(lda.bound(docsRepeatedWord)) + > -31.4413908227 + */ - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + assert(ldaModel.logLikelihood(docsSingleWord) ~== -25.971 relTol 1E-3D) + assert(ldaModel.logLikelihood(docsRepeatedWord) ~== -31.441 relTol 1E-3D) + } + + test("LocalLDAModel logPerplexity") { + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -285,32 +291,13 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { > -3.69051285096 */ - assert(ldaModel.logPerplexity(docs) ~== -3.690D relTol 1E-3D) + // Gensim's definition of perplexity is negative our (and Stanford NLP's) definition + assert(ldaModel.logPerplexity(docs) ~== 3.690D relTol 1E-3D) } test("LocalLDAModel predict") { - val k = 2 - val vocabSize = 6 - val alpha = 0.01 - val eta = 0.01 - val gammaShape = 100 - // obtained from LDA model trained in gensim, see below - val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( - 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, - 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) - - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - val docs = sc.parallelize(toydata) - - val ldaModel: LocalLDAModel = new LocalLDAModel( - topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + val docs = sc.parallelize(toyData) + val ldaModel: LocalLDAModel = toyModel /* Verify results using gensim: import numpy as np @@ -351,16 +338,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } test("OnlineLDAOptimizer with asymmetric prior") { - def toydata: Array[(Long, Vector)] = Array( - Vectors.sparse(6, Array(0, 1), Array(1, 1)), - Vectors.sparse(6, Array(1, 2), Array(1, 1)), - Vectors.sparse(6, Array(0, 2), Array(1, 1)), - Vectors.sparse(6, Array(3, 4), Array(1, 1)), - Vectors.sparse(6, Array(3, 5), Array(1, 1)), - Vectors.sparse(6, Array(4, 5), Array(1, 1)) - ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } - - val docs = sc.parallelize(toydata) + val docs = sc.parallelize(toyData) val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51) .setGammaShape(1e10) val lda = new LDA().setK(2) @@ -531,4 +509,27 @@ private[clustering] object LDASuite { def getNonEmptyDoc(corpus: Array[(Long, Vector)]): Array[(Long, Vector)] = corpus.filter { case (_, wc: Vector) => Vectors.norm(wc, p = 1.0) != 0.0 } + + def toyData: Array[(Long, Vector)] = Array( + Vectors.sparse(6, Array(0, 
1), Array(1, 1)), + Vectors.sparse(6, Array(1, 2), Array(1, 1)), + Vectors.sparse(6, Array(0, 2), Array(1, 1)), + Vectors.sparse(6, Array(3, 4), Array(1, 1)), + Vectors.sparse(6, Array(3, 5), Array(1, 1)), + Vectors.sparse(6, Array(4, 5), Array(1, 1)) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + + def toyModel: LocalLDAModel = { + val k = 2 + val vocabSize = 6 + val alpha = 0.01 + val eta = 0.01 + val gammaShape = 100 + val topics = new DenseMatrix(numRows = vocabSize, numCols = k, values = Array( + 1.86738052, 1.94056535, 1.89981687, 0.0833265, 0.07405918, 0.07940597, + 0.15081551, 0.08637973, 0.12428538, 1.9474897, 1.94615165, 1.95204124)) + val ldaModel: LocalLDAModel = new LocalLDAModel( + topics, Vectors.dense(Array.fill(k)(alpha)), eta, gammaShape) + ldaModel + } } From c0686668ae6a92b6bb4801a55c3b78aedbee816a Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 31 Jul 2015 20:27:00 +0100 Subject: [PATCH 210/219] [SPARK-9202] capping maximum number of executor&driver information kept in Worker https://issues.apache.org/jira/browse/SPARK-9202 Author: CodingCat Closes #7714 from CodingCat/SPARK-9202 and squashes the following commits: 23977fb [CodingCat] add comments about why we don't synchronize finishedExecutors & finishedDrivers dc9772d [CodingCat] addressing the comments e125241 [CodingCat] stylistic fix 80bfe52 [CodingCat] fix JsonProtocolSuite d7d9485 [CodingCat] styistic fix and respect insert ordering 031755f [CodingCat] add license info & stylistic fix c3b5361 [CodingCat] test cases and docs c557b3a [CodingCat] applications are fine 9cac751 [CodingCat] application is fine... ad87ed7 [CodingCat] trimFinishedExecutorsAndDrivers --- .../apache/spark/deploy/worker/Worker.scala | 124 ++++++++++------ .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +- .../apache/spark/deploy/DeployTestUtils.scala | 89 ++++++++++++ .../spark/deploy/JsonProtocolSuite.scala | 59 ++------ .../spark/deploy/worker/WorkerSuite.scala | 133 +++++++++++++++++- docs/configuration.md | 14 ++ 6 files changed, 329 insertions(+), 94 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 82e9578bbcba5..0276c24f85368 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -25,7 +25,7 @@ import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ -import scala.collection.mutable.{HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext import scala.util.Random import scala.util.control.NonFatal @@ -115,13 +115,18 @@ private[worker] class Worker( } var workDir: File = null - val finishedExecutors = new HashMap[String, ExecutorRunner] + val finishedExecutors = new LinkedHashMap[String, ExecutorRunner] val drivers = new HashMap[String, DriverRunner] val executors = new HashMap[String, ExecutorRunner] - val finishedDrivers = new HashMap[String, DriverRunner] + val finishedDrivers = new LinkedHashMap[String, DriverRunner] val appDirectories = new HashMap[String, Seq[String]] val finishedApps = new HashSet[String] + val retainedExecutors = conf.getInt("spark.worker.ui.retainedExecutors", + WorkerWebUI.DEFAULT_RETAINED_EXECUTORS) + val 
retainedDrivers = conf.getInt("spark.worker.ui.retainedDrivers", + WorkerWebUI.DEFAULT_RETAINED_DRIVERS) + // The shuffle service is not actually started unless configured. private val shuffleService = new ExternalShuffleService(conf, securityMgr) @@ -461,25 +466,7 @@ private[worker] class Worker( } case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => - sendToMaster(executorStateChanged) - val fullId = appId + "/" + execId - if (ExecutorState.isFinished(state)) { - executors.get(fullId) match { - case Some(executor) => - logInfo("Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - executors -= fullId - finishedExecutors(fullId) = executor - coresUsed -= executor.cores - memoryUsed -= executor.memory - case None => - logInfo("Unknown Executor " + fullId + " finished with state " + state + - message.map(" message " + _).getOrElse("") + - exitStatus.map(" exitStatus " + _).getOrElse("")) - } - maybeCleanupApplication(appId) - } + handleExecutorStateChanged(executorStateChanged) case KillExecutor(masterUrl, appId, execId) => if (masterUrl != activeMasterUrl) { @@ -523,24 +510,8 @@ private[worker] class Worker( } } - case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { - state match { - case DriverState.ERROR => - logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") - case DriverState.FAILED => - logWarning(s"Driver $driverId exited with failure") - case DriverState.FINISHED => - logInfo(s"Driver $driverId exited successfully") - case DriverState.KILLED => - logInfo(s"Driver $driverId was killed by user") - case _ => - logDebug(s"Driver $driverId changed state to $state") - } - sendToMaster(driverStageChanged) - val driver = drivers.remove(driverId).get - finishedDrivers(driverId) = driver - memoryUsed -= driver.driverDesc.mem - coresUsed -= driver.driverDesc.cores + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + handleDriverStateChanged(driverStateChanged) } case ReregisterWithMaster => @@ -614,6 +585,78 @@ private[worker] class Worker( webUi.stop() metricsSystem.stop() } + + private def trimFinishedExecutorsIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedExecutors.size > retainedExecutors) { + finishedExecutors.take(math.max(finishedExecutors.size / 10, 1)).foreach { + case (executorId, _) => finishedExecutors.remove(executorId) + } + } + } + + private def trimFinishedDriversIfNecessary(): Unit = { + // do not need to protect with locks since both WorkerPage and Restful server get data through + // thread-safe RpcEndPoint + if (finishedDrivers.size > retainedDrivers) { + finishedDrivers.take(math.max(finishedDrivers.size / 10, 1)).foreach { + case (driverId, _) => finishedDrivers.remove(driverId) + } + } + } + + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { + val driverId = driverStateChanged.driverId + val exception = driverStateChanged.exception + val state = driverStateChanged.state + state match { + case DriverState.ERROR => + logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") + case DriverState.FAILED => + logWarning(s"Driver $driverId exited with failure") + case DriverState.FINISHED => + logInfo(s"Driver $driverId exited successfully") + case 
DriverState.KILLED => + logInfo(s"Driver $driverId was killed by user") + case _ => + logDebug(s"Driver $driverId changed state to $state") + } + sendToMaster(driverStateChanged) + val driver = drivers.remove(driverId).get + finishedDrivers(driverId) = driver + trimFinishedDriversIfNecessary() + memoryUsed -= driver.driverDesc.mem + coresUsed -= driver.driverDesc.cores + } + + private[worker] def handleExecutorStateChanged(executorStateChanged: ExecutorStateChanged): + Unit = { + sendToMaster(executorStateChanged) + val state = executorStateChanged.state + if (ExecutorState.isFinished(state)) { + val appId = executorStateChanged.appId + val fullId = appId + "/" + executorStateChanged.execId + val message = executorStateChanged.message + val exitStatus = executorStateChanged.exitStatus + executors.get(fullId) match { + case Some(executor) => + logInfo("Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + executors -= fullId + finishedExecutors(fullId) = executor + trimFinishedExecutorsIfNecessary() + coresUsed -= executor.cores + memoryUsed -= executor.memory + case None => + logInfo("Unknown Executor " + fullId + " finished with state " + state + + message.map(" message " + _).getOrElse("") + + exitStatus.map(" exitStatus " + _).getOrElse("")) + } + maybeCleanupApplication(appId) + } + } } private[deploy] object Worker extends Logging { @@ -669,5 +712,4 @@ private[deploy] object Worker extends Logging { cmd } } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 334a5b10142aa..709a27233598c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -53,6 +53,8 @@ class WorkerWebUI( } } -private[ui] object WorkerWebUI { +private[worker] object WorkerWebUI { val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR + val DEFAULT_RETAINED_DRIVERS = 1000 + val DEFAULT_RETAINED_EXECUTORS = 1000 } diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala new file mode 100644 index 0000000000000..967aa0976f0ce --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy + +import java.io.File +import java.util.Date + +import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} +import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.{SecurityManager, SparkConf} + +private[deploy] object DeployTestUtils { + def createAppDesc(): ApplicationDescription = { + val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) + new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") + } + + def createAppInfo() : ApplicationInfo = { + val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, + "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) + appInfo.endTime = JsonConstants.currTimeInMillis + appInfo + } + + def createDriverCommand(): Command = new Command( + "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), + Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") + ) + + def createDriverDesc(): DriverDescription = + new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) + + def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", + createDriverDesc(), new Date()) + + def createWorkerInfo(): WorkerInfo = { + val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") + workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis + workerInfo + } + + def createExecutorRunner(execId: Int): ExecutorRunner = { + new ExecutorRunner( + "appId", + execId, + createAppDesc(), + 4, + 1234, + null, + "workerId", + "host", + 123, + "publicAddress", + new File("sparkHome"), + new File("workDir"), + "akka://worker", + new SparkConf, + Seq("localDir"), + ExecutorState.RUNNING) + } + + def createDriverRunner(driverId: String): DriverRunner = { + val conf = new SparkConf() + new DriverRunner( + conf, + driverId, + new File("workDir"), + new File("sparkHome"), + createDriverDesc(), + null, + "akka://worker", + new SecurityManager(conf)) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 08529e0ef2806..0a9f128a3a6b6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy -import java.io.File import java.util.Date import com.fasterxml.jackson.core.JsonParseException @@ -25,12 +24,14 @@ import org.json4s._ import org.json4s.jackson.JsonMethods import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} -import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo} -import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} +import org.apache.spark.deploy.worker.ExecutorRunner +import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { + import org.apache.spark.deploy.DeployTestUtils._ + test("writeApplicationInfo") { val output = JsonProtocol.writeApplicationInfo(createAppInfo()) assertValidJson(output) @@ -50,7 +51,7 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { } test("writeExecutorRunner") { - val output = 
JsonProtocol.writeExecutorRunner(createExecutorRunner()) + val output = JsonProtocol.writeExecutorRunner(createExecutorRunner(123)) assertValidJson(output) assertValidDataInJson(output, JsonMethods.parse(JsonConstants.executorRunnerJsonStr)) } @@ -77,9 +78,10 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { test("writeWorkerState") { val executors = List[ExecutorRunner]() - val finishedExecutors = List[ExecutorRunner](createExecutorRunner(), createExecutorRunner()) - val drivers = List(createDriverRunner()) - val finishedDrivers = List(createDriverRunner(), createDriverRunner()) + val finishedExecutors = List[ExecutorRunner](createExecutorRunner(123), + createExecutorRunner(123)) + val drivers = List(createDriverRunner("driverId")) + val finishedDrivers = List(createDriverRunner("driverId"), createDriverRunner("driverId")) val stateResponse = new WorkerStateResponse("host", 8080, "workerId", executors, finishedExecutors, drivers, finishedDrivers, "masterUrl", 4, 1234, 4, 1234, "masterWebUiUrl") val output = JsonProtocol.writeWorkerState(stateResponse) @@ -87,47 +89,6 @@ class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { assertValidDataInJson(output, JsonMethods.parse(JsonConstants.workerStateJsonStr)) } - def createAppDesc(): ApplicationDescription = { - val cmd = new Command("mainClass", List("arg1", "arg2"), Map(), Seq(), Seq(), Seq()) - new ApplicationDescription("name", Some(4), 1234, cmd, "appUiUrl") - } - - def createAppInfo() : ApplicationInfo = { - val appInfo = new ApplicationInfo(JsonConstants.appInfoStartTime, - "id", createAppDesc(), JsonConstants.submitDate, null, Int.MaxValue) - appInfo.endTime = JsonConstants.currTimeInMillis - appInfo - } - - def createDriverCommand(): Command = new Command( - "org.apache.spark.FakeClass", Seq("some arg --and-some options -g foo"), - Map(("K1", "V1"), ("K2", "V2")), Seq("cp1", "cp2"), Seq("lp1", "lp2"), Seq("-Dfoo") - ) - - def createDriverDesc(): DriverDescription = - new DriverDescription("hdfs://some-dir/some.jar", 100, 3, false, createDriverCommand()) - - def createDriverInfo(): DriverInfo = new DriverInfo(3, "driver-3", - createDriverDesc(), new Date()) - - def createWorkerInfo(): WorkerInfo = { - val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress") - workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis - workerInfo - } - - def createExecutorRunner(): ExecutorRunner = { - new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host", 123, - "publicAddress", new File("sparkHome"), new File("workDir"), "akka://worker", - new SparkConf, Seq("localDir"), ExecutorState.RUNNING) - } - - def createDriverRunner(): DriverRunner = { - val conf = new SparkConf() - new DriverRunner(conf, "driverId", new File("workDir"), new File("sparkHome"), - createDriverDesc(), null, "akka://worker", new SecurityManager(conf)) - } - def assertValidJson(json: JValue) { try { JsonMethods.parse(JsonMethods.compact(json)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index 0f4d3b28d09df..faed4bdc68447 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -17,13 +17,18 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.Command - import org.scalatest.Matchers +import 
org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} +import org.apache.spark.deploy.master.DriverState +import org.apache.spark.deploy.{Command, ExecutorState} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} + class WorkerSuite extends SparkFunSuite with Matchers { + import org.apache.spark.deploy.DeployTestUtils._ + def cmd(javaOpts: String*): Command = { Command("", Seq.empty, Map.empty, Seq.empty, Seq.empty, Seq(javaOpts : _*)) } @@ -56,4 +61,126 @@ class WorkerSuite extends SparkFunSuite with Matchers { "-Dspark.ssl.useNodeLocalConf=true", "-Dspark.ssl.opt1=y", "-Dspark.ssl.opt2=z") } + + test("test clearing of finishedExecutors (small number of executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize executors + for (i <- 0 until 5) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // send an ExecutorStateChanged message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 4) + for (i <- 1 until 5) { + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 2) + if (i > 1) { + assert(!worker.finishedExecutors.contains(s"app1/${i - 2}")) + } + assert(worker.executors.size === 4 - i) + } + } + + test("test clearing of finishedExecutors (more executors)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedExecutors", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize executors + for (i <- 0 until 50) { + worker.executors += s"app1/$i" -> createExecutorRunner(i) + } + // send an ExecutorStateChanged message + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", 0, ExecutorState.EXITED, None, None)) + assert(worker.finishedExecutors.size === 1) + assert(worker.executors.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedExecutors.size < 30) { + worker.finishedExecutors.size + 1 + } else { + 28 + } + } + worker.handleExecutorStateChanged( + ExecutorStateChanged("app1", i, ExecutorState.EXITED, None, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedExecutors.contains(s"app1/$j")) + } + } + assert(worker.executors.size === 49 - i) + assert(worker.finishedExecutors.size === expectedValue) + } + } + + test("test clearing of finishedDrivers (small number of drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 2.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize drivers + for (i <- 0 until 5) { + val driverId = s"driverId-$i" +
worker.drivers += driverId -> createDriverRunner(driverId) + } + // send a DriverStateChanged message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.drivers.size === 4) + assert(worker.finishedDrivers.size === 1) + for (i <- 1 until 5) { + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (i > 1) { + assert(!worker.finishedDrivers.contains(s"driverId-${i - 2}")) + } + assert(worker.drivers.size === 4 - i) + assert(worker.finishedDrivers.size === 2) + } + } + + test("test clearing of finishedDrivers (more drivers)") { + val conf = new SparkConf() + conf.set("spark.worker.ui.retainedDrivers", 30.toString) + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), + "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + // initialize drivers + for (i <- 0 until 50) { + val driverId = s"driverId-$i" + worker.drivers += driverId -> createDriverRunner(driverId) + } + // send a DriverStateChanged message + worker.handleDriverStateChanged(DriverStateChanged("driverId-0", DriverState.FINISHED, None)) + assert(worker.finishedDrivers.size === 1) + assert(worker.drivers.size === 49) + for (i <- 1 until 50) { + val expectedValue = { + if (worker.finishedDrivers.size < 30) { + worker.finishedDrivers.size + 1 + } else { + 28 + } + } + val driverId = s"driverId-$i" + worker.handleDriverStateChanged(DriverStateChanged(driverId, DriverState.FINISHED, None)) + if (expectedValue == 28) { + for (j <- i - 30 until i - 27) { + assert(!worker.finishedDrivers.contains(s"driverId-$j")) + } + } + assert(worker.drivers.size === 49 - i) + assert(worker.finishedDrivers.size === expectedValue) + } + } } diff --git a/docs/configuration.md b/docs/configuration.md index fd236137cb96e..24b606356a149 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -557,6 +557,20 @@ Apart from these, the following properties are also available, and may be useful collecting. + + spark.worker.ui.retainedExecutors + 1000 + + How many finished executors the Spark UI and status APIs remember before garbage collecting. + + + + spark.worker.ui.retainedDrivers + 1000 + + How many finished drivers the Spark UI and status APIs remember before garbage collecting. + + #### Compression and Serialization From 3c0d2e55210735e0df2f8febb5f63c224af230e3 Mon Sep 17 00:00:00 2001 From: Meihua Wu Date: Fri, 31 Jul 2015 13:01:10 -0700 Subject: [PATCH 211/219] [SPARK-9246] [MLLIB] DistributedLDAModel predict top docs per topic Add topDocumentsPerTopic to DistributedLDAModel. Add ScalaDoc and unit tests. Author: Meihua Wu Closes #7769 from rotationsymmetry/SPARK-9246 and squashes the following commits: 1029e79c [Meihua Wu] clean up code comments a023b82 [Meihua Wu] Update tests to use Long for doc index. 91e5998 [Meihua Wu] Use Long for doc index.
b9f70cf [Meihua Wu] Revise topDocumentsPerTopic 26ff3f6 [Meihua Wu] Add topDocumentsPerTopic, scala doc and unit tests --- .../spark/mllib/clustering/LDAModel.scala | 37 +++++++++++++++++++ .../spark/mllib/clustering/LDASuite.scala | 22 +++++++++++ 2 files changed, 59 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index ff7035d2246c2..0cdac84eeb591 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -516,6 +516,43 @@ class DistributedLDAModel private[clustering] ( } } + /** + * Returns the top documents for each topic. + * + * This is approximate; it may not return exactly the top-weighted documents for each topic. + * To get a more precise set of top documents, increase maxDocumentsPerTopic. + * + * @param maxDocumentsPerTopic Maximum number of documents to collect for each topic. + * @return Array over topics. Each element is a pair of matching arrays: + * (IDs of the documents, weights of the topic in those documents). + * For each topic, documents are sorted in order of decreasing topic weight. + */ + def topDocumentsPerTopic(maxDocumentsPerTopic: Int): Array[(Array[Long], Array[Double])] = { + val numTopics = k + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Long)]] = + topicDistributions.mapPartitions { docVertices => + // For this partition, collect the top-weighted docs for each topic in queues: + // queues(topic) = queue of (weight of topic in doc, doc ID). + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Long)](maxDocumentsPerTopic)) + for ((docId, docTopics) <- docVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (docTopics(topic) -> docId) + topic += 1 + } + } + Iterator(queues) + }.treeReduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b } + q1 + } + topicsInQueues.map { q => + val (docTopics, docs) = q.toArray.sortBy(-_._1).unzip + (docs.toArray, docTopics.toArray) + } + } + // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
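For illustration, a minimal usage sketch of the new API (not part of the patch itself; it assumes a trained `DistributedLDAModel` named `model`, and the printed format is illustrative):

    // Ask for the (approximate) top 3 documents per topic.
    val topDocs: Array[(Array[Long], Array[Double])] = model.topDocumentsPerTopic(3)
    topDocs.zipWithIndex.foreach { case ((docIds, weights), topic) =>
      // The two arrays are aligned: docIds(i) has topic weight weights(i) for this topic.
      println(s"topic $topic: " + docIds.zip(weights).mkString(", "))
    }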
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 79d2a1cafd1fa..f2b94707fd0ff 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -122,6 +122,28 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // Check: log probabilities assert(model.logLikelihood < 0.0) assert(model.logPrior < 0.0) + + // Check: topDocumentsPerTopic + // Compare it with top documents per topic derived from topicDistributions + val topDocsByTopicDistributions = { n: Int => + Range(0, k).map { topic => + val (doc, docWeights) = topicDistributions.sortBy(-_._2(topic)).take(n).unzip + (doc.toArray, docWeights.map(_(topic)).toArray) + }.toArray + } + + // Top 3 documents per topic + model.topDocumentsPerTopic(3).zip(topDocsByTopicDistributions(3)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } + + // All documents per topic + val q = tinyCorpus.length + model.topDocumentsPerTopic(q).zip(topDocsByTopicDistributions(q)).foreach {case (t1, t2) => + assert(t1._1 === t2._1) + assert(t1._2 === t2._2) + } } test("vertex indexing") { From 060c79aab58efd4ce7353a1b00534de0d9e1de0b Mon Sep 17 00:00:00 2001 From: Sameer Abhyankar Date: Fri, 31 Jul 2015 13:08:55 -0700 Subject: [PATCH 212/219] [SPARK-9056] [STREAMING] Rename configuration `spark.streaming.minRememberDuration` to `spark.streaming.fileStream.minRememberDuration` Rename configuration `spark.streaming.minRememberDuration` to `spark.streaming.fileStream.minRememberDuration` Author: Sameer Abhyankar Author: Sameer Abhyankar Closes #7740 from sabhyankar/spark_branch_9056 and squashes the following commits: d5b2f1f [Sameer Abhyankar] Correct deprecated version to 1.5 1268133 [Sameer Abhyankar] Add {} and indentation ddf9844 [Sameer Abhyankar] Change 4 space indentation to 2 space indentation 1819b5f [Sameer Abhyankar] Use spark.streaming.fileStream.minRememberDuration property in lieu of spark.streaming.minRememberDuration --- core/src/main/scala/org/apache/spark/SparkConf.scala | 4 +++- .../apache/spark/streaming/dstream/FileInputDStream.scala | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 4161792976c7b..08bab4bf2739f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -548,7 +548,9 @@ private[spark] object SparkConf extends Logging { "spark.rpc.askTimeout" -> Seq( AlternateConfig("spark.akka.askTimeout", "1.4")), "spark.rpc.lookupTimeout" -> Seq( - AlternateConfig("spark.akka.lookupTimeout", "1.4")) + AlternateConfig("spark.akka.lookupTimeout", "1.4")), + "spark.streaming.fileStream.minRememberDuration" -> Seq( + AlternateConfig("spark.streaming.minRememberDuration", "1.5")) ) /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index dd4da9d9ca6a2..c358f5b5bd70b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -86,8 +86,10 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( * Files with mod times older than this "window" 
of remembering will be ignored. So if new * files are visible within this window, then the file will get selected in the next batch. */ - private val minRememberDurationS = - Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.minRememberDuration", "60s")) + private val minRememberDurationS = { + Seconds(ssc.conf.getTimeAsSeconds("spark.streaming.fileStream.minRememberDuration", + ssc.conf.get("spark.streaming.minRememberDuration", "60s"))) + } // This is a def so that it works during checkpoint recovery: private def clock = ssc.scheduler.clock From fbef566a107b47e5fddde0ea65b8587d5039062d Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 Jul 2015 13:11:42 -0700 Subject: [PATCH 213/219] [SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities Make NaiveBayesModel support predicting class probabilities, inherit from ProbabilisticClassificationModel. Author: Yanbo Liang Closes #7672 from yanboliang/spark-9308 and squashes the following commits: 25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place 3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities --- .../spark/ml/classification/NaiveBayes.scala | 65 ++++++++++++++----- .../ml/classification/NaiveBayesSuite.scala | 54 ++++++++++++++- 2 files changed, 101 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 5be35fe209291..b46b676204e0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams { * The input feature values must be nonnegative. 
*/ class NaiveBayes(override val uid: String) - extends Predictor[Vector, NaiveBayes, NaiveBayesModel] + extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams { def this() = this(Identifiable.randomUID("nb")) @@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] ( override val uid: String, val pi: Vector, val theta: Matrix) - extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams { + extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams { import OldNaiveBayes.{Bernoulli, Multinomial} @@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] ( throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } - override protected def predict(features: Vector): Double = { + override val numClasses: Int = pi.size + + private def multinomialCalculation(features: Vector) = { + val prob = theta.multiply(features) + BLAS.axpy(1.0, pi, prob) + prob + } + + private def bernoulliCalculation(features: Vector) = { + features.foreachActive((_, value) => + if (value != 0.0 && value != 1.0) { + throw new SparkException( + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.") + } + ) + val prob = thetaMinusNegTheta.get.multiply(features) + BLAS.axpy(1.0, pi, prob) + BLAS.axpy(1.0, negThetaSum.get, prob) + prob + } + + override protected def predictRaw(features: Vector): Vector = { $(modelType) match { case Multinomial => - val prob = theta.multiply(features) - BLAS.axpy(1.0, pi, prob) - prob.argmax + multinomialCalculation(features) case Bernoulli => - features.foreachActive{ (index, value) => - if (value != 0.0 && value != 1.0) { - throw new SparkException( - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features") - } - } - val prob = thetaMinusNegTheta.get.multiply(features) - BLAS.axpy(1.0, pi, prob) - BLAS.axpy(1.0, negThetaSum.get, prob) - prob.argmax + bernoulliCalculation(features) case _ => // This should never happen. 
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.") } } + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val maxLog = dv.values.max + while (i < size) { + dv.values(i) = math.exp(dv.values(i) - maxLog) + i += 1 + } + val probSum = dv.values.sum + i = 0 + while (i < size) { + dv.values(i) = dv.values(i) / probSum + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in NaiveBayesModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + override def copy(extra: ParamMap): NaiveBayesModel = { copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 264bde3703c5f..aea3d9b694490 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark.ml.classification +import breeze.linalg.{Vector => BV} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.classification.NaiveBayes import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ @@ -28,6 +31,8 @@ import org.apache.spark.sql.Row class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { + import NaiveBayes.{Multinomial, Bernoulli} + def validatePrediction(predictionAndLabels: DataFrame): Unit = { val numOfErrorPredictions = predictionAndLabels.collect().count { case Row(prediction: Double, label: Double) => @@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch") } + def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = { + val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v))) + val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v)) + val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze + val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze + val classProbs = logClassProbs.toArray.map(math.exp) + val classProbsSum = classProbs.sum + Vectors.dense(classProbs.map(_ / classProbsSum)) + } + + def validateProbabilities( + featureAndProbabilities: DataFrame, + model: NaiveBayesModel, + modelType: String): Unit = { + featureAndProbabilities.collect().foreach { + case Row(features: Vector, probability: Vector) => { + assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) + val expected = modelType match { + case Multinomial => + expectedMultinomialProbabilities(model, features) + case Bernoulli => + expectedBernoulliProbabilities(model, features) + case _ => + throw new UnknownError(s"Invalid modelType: $modelType.") + } + assert(probability ~== expected relTol 1.0e-10) + } + } + } + 
test("params") { ParamsSuite.checkParams(new NaiveBayes) val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)), @@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 17, "multinomial")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "multinomial") } test("Naive Bayes Bernoulli") { @@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext { val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput( piArray, thetaArray, nPoints, 20, "bernoulli")) - val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") + val predictionAndLabels = model.transform(validationDataset).select("prediction", "label") validatePrediction(predictionAndLabels) + + val featureAndProbabilities = model.transform(validationDataset) + .select("features", "probability") + validateProbabilities(featureAndProbabilities, model, "bernoulli") } } From 815c8245f47e61226a04e2e02f508457b5e9e536 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 31 Jul 2015 13:45:12 -0700 Subject: [PATCH 214/219] [SPARK-9466] [SQL] Increate two timeouts in CliSuite. Hopefully this can resolve the flakiness of this suite. JIRA: https://issues.apache.org/jira/browse/SPARK-9466 Author: Yin Huai Closes #7777 from yhuai/SPARK-9466 and squashes the following commits: e0e3a86 [Yin Huai] Increate the timeout. 
--- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 13b0c5951dddc..df80d04b40801 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -137,7 +137,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { } test("Single command with --database") { - runCliWithin(1.minute)( + runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" -> "OK", "USE hive_test_db;" @@ -148,7 +148,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfter with Logging { -> "Time taken: " ) - runCliWithin(1.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( + runCliWithin(2.minute, Seq("--database", "hive_test_db", "-e", "SHOW TABLES;"))( "" -> "OK", "" From 873ab0f9692d8ea6220abdb8d9200041068372a8 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 31 Jul 2015 13:45:28 -0700 Subject: [PATCH 215/219] [SPARK-9490] [DOCS] [MLLIB] MLlib evaluation metrics guide example python code uses deprecated print statement Use print(x) not print x for Python 3 in eval examples CC sethah mengxr -- just wanted to close this out before 1.5 Author: Sean Owen Closes #7822 from srowen/SPARK-9490 and squashes the following commits: 01abeba [Sean Owen] Change "print x" to "print(x)" in the rest of the docs too bd7f7fb [Sean Owen] Use print(x) not print x for Python 3 in eval examples --- docs/ml-guide.md | 2 +- docs/mllib-evaluation-metrics.md | 66 ++++++++++++++--------------- docs/mllib-feature-extraction.md | 2 +- docs/mllib-statistics.md | 20 ++++----- docs/quick-start.md | 2 +- docs/sql-programming-guide.md | 6 +-- docs/streaming-programming-guide.md | 2 +- 7 files changed, 50 insertions(+), 50 deletions(-) diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 8c46adf256a9a..b6ca50e98db02 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -561,7 +561,7 @@ test = sc.parallelize([(4L, "spark i j k"), prediction = model.transform(test) selected = prediction.select("id", "text", "prediction") for row in selected.collect(): - print row + print(row) sc.stop() {% endhighlight %} diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md index 4ca0bb06b26a6..7066d5c97418c 100644 --- a/docs/mllib-evaluation-metrics.md +++ b/docs/mllib-evaluation-metrics.md @@ -302,10 +302,10 @@ predictionAndLabels = test.map(lambda lp: (float(model.predict(lp.features)), lp metrics = BinaryClassificationMetrics(predictionAndLabels) # Area under precision-recall curve -print "Area under PR = %s" % metrics.areaUnderPR +print("Area under PR = %s" % metrics.areaUnderPR) # Area under ROC curve -print "Area under ROC = %s" % metrics.areaUnderROC +print("Area under ROC = %s" % metrics.areaUnderROC) {% endhighlight %} @@ -606,24 +606,24 @@ metrics = MulticlassMetrics(predictionAndLabels) precision = metrics.precision() recall = metrics.recall() f1Score = metrics.fMeasure() -print "Summary Stats" -print "Precision = %s" % precision -print "Recall = %s" % recall -print "F1 Score = %s" % f1Score +print("Summary Stats") +print("Precision = %s" % precision) +print("Recall = %s" % recall) +print("F1 Score = %s" % f1Score) # Statistics by class labels = data.map(lambda lp: 
lp.label).distinct().collect() for label in sorted(labels): - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.fMeasure(label, beta=1.0))) # Weighted stats -print "Weighted recall = %s" % metrics.weightedRecall -print "Weighted precision = %s" % metrics.weightedPrecision -print "Weighted F(1) Score = %s" % metrics.weightedFMeasure() -print "Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5) -print "Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate +print("Weighted recall = %s" % metrics.weightedRecall) +print("Weighted precision = %s" % metrics.weightedPrecision) +print("Weighted F(1) Score = %s" % metrics.weightedFMeasure()) +print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5)) +print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate) {% endhighlight %} @@ -881,28 +881,28 @@ scoreAndLabels = sc.parallelize([ metrics = MultilabelMetrics(scoreAndLabels) # Summary stats -print "Recall = %s" % metrics.recall() -print "Precision = %s" % metrics.precision() -print "F1 measure = %s" % metrics.f1Measure() -print "Accuracy = %s" % metrics.accuracy +print("Recall = %s" % metrics.recall()) +print("Precision = %s" % metrics.precision()) +print("F1 measure = %s" % metrics.f1Measure()) +print("Accuracy = %s" % metrics.accuracy) # Individual label stats labels = scoreAndLabels.flatMap(lambda x: x[1]).distinct().collect() for label in labels: - print "Class %s precision = %s" % (label, metrics.precision(label)) - print "Class %s recall = %s" % (label, metrics.recall(label)) - print "Class %s F1 Measure = %s" % (label, metrics.f1Measure(label)) + print("Class %s precision = %s" % (label, metrics.precision(label))) + print("Class %s recall = %s" % (label, metrics.recall(label))) + print("Class %s F1 Measure = %s" % (label, metrics.f1Measure(label))) # Micro stats -print "Micro precision = %s" % metrics.microPrecision -print "Micro recall = %s" % metrics.microRecall -print "Micro F1 measure = %s" % metrics.microF1Measure +print("Micro precision = %s" % metrics.microPrecision) +print("Micro recall = %s" % metrics.microRecall) +print("Micro F1 measure = %s" % metrics.microF1Measure) # Hamming loss -print "Hamming loss = %s" % metrics.hammingLoss +print("Hamming loss = %s" % metrics.hammingLoss) # Subset accuracy -print "Subset accuracy = %s" % metrics.subsetAccuracy +print("Subset accuracy = %s" % metrics.subsetAccuracy) {% endhighlight %} @@ -1283,10 +1283,10 @@ scoreAndLabels = predictions.join(ratingsTuple).map(lambda tup: tup[1]) metrics = RegressionMetrics(scoreAndLabels) # Root mean squared error -print "RMSE = %s" % metrics.rootMeanSquaredError +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared = %s" % metrics.r2 +print("R-squared = %s" % metrics.r2) {% endhighlight %} @@ -1479,17 +1479,17 @@ valuesAndPreds = parsedData.map(lambda p: (float(model.predict(p.features)), p.l metrics = RegressionMetrics(valuesAndPreds) # Squared Error -print "MSE = %s" % metrics.meanSquaredError -print "RMSE = %s" % metrics.rootMeanSquaredError +print("MSE = %s" % metrics.meanSquaredError) +print("RMSE = %s" % metrics.rootMeanSquaredError) # R-squared -print "R-squared =
%s" % metrics.r2 +print("R-squared = %s" % metrics.r2) # Mean absolute error -print "MAE = %s" % metrics.meanAbsoluteError +print("MAE = %s" % metrics.meanAbsoluteError) # Explained variance -print "Explained variance = %s" % metrics.explainedVariance +print("Explained variance = %s" % metrics.explainedVariance) {% endhighlight %} diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index a69e41e2a1936..de86aba2ae627 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -221,7 +221,7 @@ model = word2vec.fit(inp) synonyms = model.findSynonyms('china', 40) for word, cosine_distance in synonyms: - print "{}: {}".format(word, cosine_distance) + print("{}: {}".format(word, cosine_distance)) {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index de5d6485f9b5f..be04d0b4b53a8 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -95,9 +95,9 @@ mat = ... # an RDD of Vectors # Compute column summary statistics. summary = Statistics.colStats(mat) -print summary.mean() -print summary.variance() -print summary.numNonzeros() +print(summary.mean()) +print(summary.variance()) +print(summary.numNonzeros()) {% endhighlight %} @@ -183,12 +183,12 @@ seriesY = ... # must have the same number of partitions and cardinality as serie # Compute the correlation using Pearson's method. Enter "spearman" for Spearman's method. If a # method is not specified, Pearson's method will be used by default. -print Statistics.corr(seriesX, seriesY, method="pearson") +print(Statistics.corr(seriesX, seriesY, method="pearson")) data = ... # an RDD of Vectors # calculate the correlation matrix using Pearson's method. Use "spearman" for Spearman's method. # If a method is not specified, Pearson's method will be used by default. -print Statistics.corr(data, method="pearson") +print(Statistics.corr(data, method="pearson")) {% endhighlight %} @@ -398,14 +398,14 @@ vec = Vectors.dense(...) # a vector composed of the frequencies of events # compute the goodness of fit. If a second vector to test against is not supplied as a parameter, # the test runs against a uniform distribution. goodnessOfFitTestResult = Statistics.chiSqTest(vec) -print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, - # test statistic, the method used, and the null hypothesis. +print(goodnessOfFitTestResult) # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. mat = Matrices.dense(...) # a contingency matrix # conduct Pearson's independence test on the input contingency matrix independenceTestResult = Statistics.chiSqTest(mat) -print independenceTestResult # summary of the test including the p-value, degrees of freedom... +print(independenceTestResult) # summary of the test including the p-value, degrees of freedom... obs = sc.parallelize(...) # LabeledPoint(feature, label) . @@ -415,8 +415,8 @@ obs = sc.parallelize(...) # LabeledPoint(feature, label) . 
featureTestResults = Statistics.chiSqTest(obs) for i, result in enumerate(featureTestResults): - print "Column %d:" % (i + 1) - print result + print("Column %d:" % (i + 1)) + print(result) {% endhighlight %} diff --git a/docs/quick-start.md b/docs/quick-start.md index bb39e4111f244..ce2cc9d2169cd 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -406,7 +406,7 @@ logData = sc.textFile(logFile).cache() numAs = logData.filter(lambda s: 'a' in s).count() numBs = logData.filter(lambda s: 'b' in s).count() -print "Lines with a: %i, lines with b: %i" % (numAs, numBs) +print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) {% endhighlight %} diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 95945eb7fc8a0..d31baa080cbce 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -570,7 +570,7 @@ teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 1 # The results of SQL queries are RDDs and support all the normal RDD operations. teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} @@ -752,7 +752,7 @@ results = sqlContext.sql("SELECT name FROM people") # The results of SQL queries are RDDs and support all the normal RDD operations. names = results.map(lambda p: "Name: " + p.name) for name in names.collect(): - print name + print(name) {% endhighlight %} @@ -1006,7 +1006,7 @@ parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): - print teenName + print(teenName) {% endhighlight %} diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 2f3013b533eb0..4663b3f14c527 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1525,7 +1525,7 @@ def getSqlContextInstance(sparkContext): words = ... # DStream of strings def process(time, rdd): - print "========= %s =========" % str(time) + print("========= %s =========" % str(time)) try: # Get the singleton instance of SQLContext sqlContext = getSqlContextInstance(rdd.context) From 6e5fd613ea4b9aa0ab485ba681277a51a4367168 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 31 Jul 2015 21:51:55 +0100 Subject: [PATCH 216/219] [SPARK-9507] [BUILD] Remove dependency reduced POM hack now that shade plugin is updated Update to shade plugin 2.4.1, which removes the need for the dependency-reduced-POM workaround and the 'release' profile. Fix management of shade plugin version so children inherit it; bump assembly plugin version while here See https://issues.apache.org/jira/browse/SPARK-8819 I verified that `mvn clean package -DskipTests` works with Maven 3.3.3. pwendell are you up for trying this for the 1.5.0 release? Author: Sean Owen Closes #7826 from srowen/SPARK-9507 and squashes the following commits: e0b0fd2 [Sean Owen] Update to shade plugin 2.4.1, which removes the need for the dependency-reduced-POM workaround and the 'release' profile.
Fix management of shade plugin version so children inherit it; bump assembly plugin version while here --- dev/create-release/create-release.sh | 4 ++-- pom.xml | 33 +++++----------------------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 86a7a4068c40e..4311c8c9e4ca6 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -118,13 +118,13 @@ if [[ ! "$@" =~ --skip-publish ]]; then rm -rf $SPARK_REPO - build/mvn -DskipTests -Pyarn -Phive -Prelease\ + build/mvn -DskipTests -Pyarn -Phive \ -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install ./dev/change-scala-version.sh 2.11 - build/mvn -DskipTests -Pyarn -Phive -Prelease\ + build/mvn -DskipTests -Pyarn -Phive \ -Dscala-2.11 -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ clean install diff --git a/pom.xml b/pom.xml index e351c7c19df96..1371a1b6bd9f1 100644 --- a/pom.xml +++ b/pom.xml @@ -160,9 +160,6 @@ 2.4.4 1.1.1.7 1.1.2 - - false - ${java.home} - ${create.dependency.reduced.pom} @@ -1836,26 +1835,6 @@ - - - release - - - true - - -