From 58b9ca1e6f7768b23e752dabc30468c06d0e1c57 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 10 Feb 2020 16:23:44 +0900 Subject: [PATCH 001/185] [SPARK-30592][SQL][FOLLOWUP] Add some round-trip test cases ### What changes were proposed in this pull request? Add round-trip tests for CSV and JSON functions as https://github.com/apache/spark/pull/27317#discussion_r376745135 asked. ### Why are the changes needed? improve test coverage ### Does this PR introduce any user-facing change? no ### How was this patch tested? add uts Closes #27510 from yaooqinn/SPARK-30592-F. Authored-by: Kent Yao Signed-off-by: HyukjinKwon --- .../resources/sql-tests/inputs/interval.sql | 14 +++++-- .../sql-tests/results/ansi/interval.sql.out | 38 ++++++++----------- .../sql-tests/results/interval.sql.out | 38 ++++++++----------- 3 files changed, 40 insertions(+), 50 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index fb6c485f619ae..a4e621e9639d4 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -222,7 +222,13 @@ select a * 1.1 from values (interval '-2147483648 months', interval '2147483647 select a / 0.5 from values (interval '-2147483648 months', interval '2147483647 months') t(a, b); -- interval support for csv and json functions -SELECT from_csv('1, 1 day', 'a INT, b interval'); -SELECT to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)); -SELECT from_json('{"a":"1 days"}', 'a interval'); -SELECT to_json(map('a', interval 25 month 100 day 130 minute)); +SELECT + from_csv('1, 1 day', 'a INT, b interval'), + to_csv(from_csv('1, 1 day', 'a INT, b interval')), + to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), + from_csv(to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), 'a interval, b interval'); +SELECT + from_json('{"a":"1 days"}', 'a interval'), + to_json(from_json('{"a":"1 days"}', 'a interval')), + to_json(map('a', interval 25 month 100 day 130 minute)), + from_json(to_json(map('a', interval 25 month 100 day 130 minute)), 'a interval'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index f37049064d989..7fdb4c53d1dcb 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 101 +-- Number of queries: 99 -- !query @@ -988,32 +988,24 @@ integer overflow -- !query -SELECT from_csv('1, 1 day', 'a INT, b interval') --- !query schema -struct> --- !query output -{"a":1,"b":1 days} - - --- !query -SELECT to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)) --- !query schema -struct --- !query output -2 years 8 months,1 hours 10 minutes - - --- !query -SELECT from_json('{"a":"1 days"}', 'a interval') +SELECT + from_csv('1, 1 day', 'a INT, b interval'), + to_csv(from_csv('1, 1 day', 'a INT, b interval')), + to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), + from_csv(to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), 'a interval, b interval') -- !query schema -struct> +struct,to_csv(from_csv(1, 1 day)):string,to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes')):string,from_csv(to_csv(named_struct(a, 
INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes'))):struct> -- !query output -{"a":1 days} +{"a":1,"b":1 days} 1,1 days 2 years 8 months,1 hours 10 minutes {"a":2 years 8 months,"b":1 hours 10 minutes} -- !query -SELECT to_json(map('a', interval 25 month 100 day 130 minute)) +SELECT + from_json('{"a":"1 days"}', 'a interval'), + to_json(from_json('{"a":"1 days"}', 'a interval')), + to_json(map('a', interval 25 month 100 day 130 minute)), + from_json(to_json(map('a', interval 25 month 100 day 130 minute)), 'a interval') -- !query schema -struct +struct,to_json(from_json({"a":"1 days"})):string,to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes')):string,from_json(to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes'))):struct> -- !query output -{"a":"2 years 1 months 100 days 2 hours 10 minutes"} +{"a":1 days} {"a":"1 days"} {"a":"2 years 1 months 100 days 2 hours 10 minutes"} {"a":2 years 1 months 100 days 2 hours 10 minutes} diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 94b4f15815ca5..3c4b4301d0025 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 101 +-- Number of queries: 99 -- !query @@ -969,32 +969,24 @@ integer overflow -- !query -SELECT from_csv('1, 1 day', 'a INT, b interval') --- !query schema -struct> --- !query output -{"a":1,"b":1 days} - - --- !query -SELECT to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)) --- !query schema -struct --- !query output -2 years 8 months,1 hours 10 minutes - - --- !query -SELECT from_json('{"a":"1 days"}', 'a interval') +SELECT + from_csv('1, 1 day', 'a INT, b interval'), + to_csv(from_csv('1, 1 day', 'a INT, b interval')), + to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), + from_csv(to_csv(named_struct('a', interval 32 month, 'b', interval 70 minute)), 'a interval, b interval') -- !query schema -struct> +struct,to_csv(from_csv(1, 1 day)):string,to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes')):string,from_csv(to_csv(named_struct(a, INTERVAL '2 years 8 months', b, INTERVAL '1 hours 10 minutes'))):struct> -- !query output -{"a":1 days} +{"a":1,"b":1 days} 1,1 days 2 years 8 months,1 hours 10 minutes {"a":2 years 8 months,"b":1 hours 10 minutes} -- !query -SELECT to_json(map('a', interval 25 month 100 day 130 minute)) +SELECT + from_json('{"a":"1 days"}', 'a interval'), + to_json(from_json('{"a":"1 days"}', 'a interval')), + to_json(map('a', interval 25 month 100 day 130 minute)), + from_json(to_json(map('a', interval 25 month 100 day 130 minute)), 'a interval') -- !query schema -struct +struct,to_json(from_json({"a":"1 days"})):string,to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes')):string,from_json(to_json(map(a, INTERVAL '2 years 1 months 100 days 2 hours 10 minutes'))):struct> -- !query output -{"a":"2 years 1 months 100 days 2 hours 10 minutes"} +{"a":1 days} {"a":"1 days"} {"a":"2 years 1 months 100 days 2 hours 10 minutes"} {"a":2 years 1 months 100 days 2 hours 10 minutes} From 70e545a94d47afb2848c24e81c908d28d41016da Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Mon, 10 Feb 2020 19:04:49 +0800 Subject: [PATCH 002/185] [SPARK-30757][SQL][DOC] Update the doc on TableCatalog.alterTable's behavior ### What changes 
were proposed in this pull request? This PR updates the documentation on `TableCatalog.alterTable`s behavior on the order by which the requested changes are applied. It now explicitly mentions that the changes are applied in the order given. ### Why are the changes needed? The current documentation on `TableCatalog.alterTable` doesn't mention which order the requested changes are applied. It will be useful to explicitly document this behavior so that the user can expect the behavior. For example, `REPLACE COLUMNS` needs to delete columns before adding new columns, and if the order is guaranteed by `alterTable`, it's much easier to work with the catalog API. ### Does this PR introduce any user-facing change? Yes, document change. ### How was this patch tested? Not added (doc changes). Closes #27496 from imback82/catalog_table_alter_table. Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/connector/catalog/TableCatalog.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index a69b23bf84d0c..2f102348ec517 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -134,6 +134,8 @@ Table createTable( * Implementations may reject the requested changes. If any change is rejected, none of the * changes should be applied to the table. *

+ * The requested changes must be applied in the order given.
+ * <p>

* If the catalog supports views and contains a view for the identifier and not a table, this * must throw {@link NoSuchTableException}. * From 5a240603fd920e3cb5d9ef49c31d46df8a630d8c Mon Sep 17 00:00:00 2001 From: jiake Date: Mon, 10 Feb 2020 21:48:00 +0800 Subject: [PATCH 003/185] [SPARK-30719][SQL] Add unit test to verify the log warning print when intentionally skip AQE ### What changes were proposed in this pull request? This is a follow up in [#27452](https://github.com/apache/spark/pull/27452). Add a unit test to verify whether the log warning is print when intentionally skip AQE. ### Why are the changes needed? Add unit test ### Does this PR introduce any user-facing change? No ### How was this patch tested? adding unit test Closes #27515 from JkSelf/aqeLoggingWarningTest. Authored-by: jiake Signed-off-by: Wenchen Fan --- .../adaptive/AdaptiveQueryExecSuite.scala | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 96e977221e512..a2071903bea7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -789,4 +789,19 @@ class AdaptiveQueryExecSuite assert(plan.isInstanceOf[AdaptiveSparkPlanExec]) } } + + test("SPARK-30719: do not log warning if intentionally skip AQE") { + val testAppender = new LogAppender("aqe logging warning test when skip") + withLogAppender(testAppender) { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val plan = sql("SELECT * FROM testData").queryExecution.executedPlan + assert(!plan.isInstanceOf[AdaptiveSparkPlanExec]) + } + } + assert(!testAppender.loggingEvents + .exists(msg => msg.getRenderedMessage.contains( + s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} is" + + s" enabled but is not supported for"))) + } } From b2011a295bd78b3693a516e049e90250366b8f52 Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Mon, 10 Feb 2020 23:41:39 +0800 Subject: [PATCH 004/185] [SPARK-30326][SQL] Raise exception if analyzer exceed max iterations ### What changes were proposed in this pull request? Enhance RuleExecutor strategy to take different actions when exceeding max iterations. And raise exception if analyzer exceed max iterations. ### Why are the changes needed? Currently, both analyzer and optimizer just log warning message if rule execution exceed max iterations. They should have different behavior. Analyzer should raise exception to indicates the plan is not fixed after max iterations, while optimizer just log warning to keep the current plan. This is more feasible after SPARK-30138 was introduced. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Add test in AnalysisSuite Closes #26977 from Eric5553/EnhanceMaxIterations. 
Authored-by: Eric Wu <492960551@qq.com> Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 10 ++++++- .../sql/catalyst/optimizer/Optimizer.scala | 5 +++- .../sql/catalyst/rules/RuleExecutor.scala | 27 ++++++++++++++++--- .../sql/catalyst/analysis/AnalysisSuite.scala | 25 ++++++++++++++++- 4 files changed, 60 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 75f1aa7185ef3..ce82b3b567b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -176,7 +176,15 @@ class Analyzer( def resolver: Resolver = conf.resolver - protected val fixedPoint = FixedPoint(maxIterations) + /** + * If the plan cannot be resolved within maxIterations, analyzer will throw exception to inform + * user to increase the value of SQLConf.ANALYZER_MAX_ITERATIONS. + */ + protected val fixedPoint = + FixedPoint( + maxIterations, + errorOnExceed = true, + maxIterationsSetting = SQLConf.ANALYZER_MAX_ITERATIONS.key) /** * Override to provide additional rules for the "Resolution" batch. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 935d62015afa1..08acac18f48bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -53,7 +53,10 @@ abstract class Optimizer(catalogManager: CatalogManager) "PartitionPruning", "Extract Python UDFs") - protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations) + protected def fixedPoint = + FixedPoint( + SQLConf.get.optimizerMaxIterations, + maxIterationsSetting = SQLConf.OPTIMIZER_MAX_ITERATIONS.key) /** * Defines the default rule batches in the Optimizer. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 287ae0e8e9f67..da5242bee28e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -45,7 +45,17 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * An execution strategy for rules that indicates the maximum number of executions. If the * execution reaches fix point (i.e. converge) before maxIterations, it will stop. */ - abstract class Strategy { def maxIterations: Int } + abstract class Strategy { + + /** The maximum number of executions. */ + def maxIterations: Int + + /** Whether to throw exception when exceeding the maximum number. */ + def errorOnExceed: Boolean = false + + /** The key of SQLConf setting to tune maxIterations */ + def maxIterationsSetting: String = null + } /** A strategy that is run once and idempotent. */ case object Once extends Strategy { val maxIterations = 1 } @@ -54,7 +64,10 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * A strategy that runs until fix point or maxIterations times, whichever comes first. * Especially, a FixedPoint(1) batch is supposed to run only once. 
*/ - case class FixedPoint(maxIterations: Int) extends Strategy + case class FixedPoint( + override val maxIterations: Int, + override val errorOnExceed: Boolean = false, + override val maxIterationsSetting: String = null) extends Strategy /** A batch of rules. */ protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) @@ -155,8 +168,14 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" - if (Utils.isTesting) { + val endingMsg = if (batch.strategy.maxIterationsSetting == null) { + "." + } else { + s", please set '${batch.strategy.maxIterationsSetting}' to a larger value." + } + val message = s"Max iterations (${iteration - 1}) reached for batch ${batch.name}" + + s"$endingMsg" + if (Utils.isTesting || batch.strategy.errorOnExceed) { throw new TreeNodeException(curPlan, message, null) } else { logWarning(message) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index c747d394b1bc2..d38513319388b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -25,9 +25,10 @@ import org.scalatest.Matchers import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count, Sum} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan @@ -745,4 +746,26 @@ class AnalysisSuite extends AnalysisTest with Matchers { CollectMetrics("evt1", sumWithFilter :: Nil, testRelation), "aggregates with filter predicate are not allowed" :: Nil) } + + test("Analysis exceed max iterations") { + // RuleExecutor only throw exception or log warning when the rule is supposed to run + // more than once. 
+ val maxIterations = 2 + val conf = new SQLConf().copy(SQLConf.ANALYZER_MAX_ITERATIONS -> maxIterations) + val testAnalyzer = new Analyzer( + new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf), conf) + + val plan = testRelation2.select( + $"a" / Literal(2) as "div1", + $"a" / $"b" as "div2", + $"a" / $"c" as "div3", + $"a" / $"d" as "div4", + $"e" / $"e" as "div5") + + val message = intercept[TreeNodeException[LogicalPlan]] { + testAnalyzer.execute(plan) + }.getMessage + assert(message.startsWith(s"Max iterations ($maxIterations) reached for batch Resolution, " + + s"please set '${SQLConf.ANALYZER_MAX_ITERATIONS.key}' to a larger value.")) + } } From acfdb46a60fc06dac0af55951492d74b7073f546 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 10 Feb 2020 10:45:00 -0800 Subject: [PATCH 005/185] [SPARK-27946][SQL][FOLLOW-UP] Change doc and error message for SHOW CREATE TABLE ### What changes were proposed in this pull request? This is a follow-up for #24938 to tweak error message and migration doc. ### Why are the changes needed? Making user know workaround if SHOW CREATE TABLE doesn't work for some Hive tables. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing unit tests. Closes #27505 from viirya/SPARK-27946-followup. Authored-by: Liang-Chi Hsieh Signed-off-by: Liang-Chi Hsieh --- docs/sql-migration-guide.md | 2 +- .../org/apache/spark/sql/execution/command/tables.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index be0fe32ded99b..26eb5838892b4 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -326,7 +326,7 @@ license: | - Since Spark 3.0, `SHOW TBLPROPERTIES` will cause `AnalysisException` if the table does not exist. In Spark version 2.4 and earlier, this scenario caused `NoSuchTableException`. Also, `SHOW TBLPROPERTIES` on a temporary view will cause `AnalysisException`. In Spark version 2.4 and earlier, it returned an empty result. - - Since Spark 3.0, `SHOW CREATE TABLE` will always return Spark DDL, even when the given table is a Hive serde table. For Hive DDL, please use `SHOW CREATE TABLE AS SERDE` command instead. + - Since Spark 3.0, `SHOW CREATE TABLE` will always return Spark DDL, even when the given table is a Hive serde table. For generating Hive DDL, please use `SHOW CREATE TABLE AS SERDE` command instead. - Since Spark 3.0, we upgraded the built-in Hive from 1.2 to 2.3. This may need to set `spark.sql.hive.metastore.version` and `spark.sql.hive.metastore.jars` according to the version of the Hive metastore. For example: set `spark.sql.hive.metastore.version` to `1.2.1` and `spark.sql.hive.metastore.jars` to `maven` if your Hive metastore version is 1.2.1. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 468ca505cce1f..90dbdf5515d4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -1076,7 +1076,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) "Failed to execute SHOW CREATE TABLE against table " + s"${tableMetadata.identifier}, which is created by Hive and uses the " + "following unsupported feature(s)\n" + - tableMetadata.unsupportedFeatures.map(" - " + _).mkString("\n") + tableMetadata.unsupportedFeatures.map(" - " + _).mkString("\n") + ". " + + s"Please use `SHOW CREATE TABLE ${tableMetadata.identifier} AS SERDE` " + + "to show Hive DDL instead." ) } @@ -1086,7 +1088,9 @@ case class ShowCreateTableCommand(table: TableIdentifier) if ("true".equalsIgnoreCase(tableMetadata.properties.getOrElse("transactional", "false"))) { throw new AnalysisException( - "SHOW CREATE TABLE doesn't support transactional Hive table") + "SHOW CREATE TABLE doesn't support transactional Hive table. " + + s"Please use `SHOW CREATE TABLE ${tableMetadata.identifier} AS SERDE` " + + "to show Hive DDL instead.") } convertTableMetadata(tableMetadata) From 4439b29bd2ac0c3cc4c6ceea825fc797ff0029a3 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Mon, 10 Feb 2020 10:56:43 -0800 Subject: [PATCH 006/185] Revert "[SPARK-30245][SQL] Add cache for Like and RLike when pattern is not static" ### What changes were proposed in this pull request? This reverts commit 8ce7962931680c204e84dd75783b1c943ea9c525. There's variable name conflicts with https://github.com/apache/spark/commit/8aebc80e0e67bcb1aa300b8c8b1a209159237632#diff-39298b470865a4cbc67398a4ea11e767. This can be cleanly ported back to branch-3.0. ### Why are the changes needed? Performance investigation were not made enough and it's not clear if it really beneficial or now. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Jenkins tests. Closes #27514 from HyukjinKwon/revert-cache-PR. Authored-by: HyukjinKwon Signed-off-by: Xiao Li --- .../expressions/regexpExpressions.scala | 21 ++++++------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index c9ddc70bf5bc6..f84c476ea5807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -177,6 +177,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """) } } else { + val patternStr = ctx.freshName("patternStr") + val compiledPattern = ctx.freshName("compiledPattern") // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. 
@@ -185,17 +187,11 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) } else { escapeChar } - val patternStr = ctx.freshName("patternStr") - val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern") - val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr") - nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => { s""" String $patternStr = $eval2.toString(); - if (!$patternStr.equals($lastPatternStr)) { - $compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar')); - $lastPatternStr = $patternStr; - } + $patternClass $compiledPattern = $patternClass.compile( + $escapeFunc($patternStr, '$newEscapeChar')); ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches(); """ }) @@ -278,16 +274,11 @@ case class RLike(left: Expression, right: Expression) } } else { val rightStr = ctx.freshName("rightStr") - val pattern = ctx.addMutableState(patternClass, "pattern") - val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr") - + val pattern = ctx.freshName("pattern") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); - if (!$rightStr.equals($lastRightStr)) { - $pattern = $patternClass.compile($rightStr); - $lastRightStr = $rightStr; - } + $patternClass $pattern = $patternClass.compile($rightStr); ${ev.value} = $pattern.matcher($eval1.toString()).find(0); """ }) From 3c1c9b48fcca1a714e6c2a3045b512598438d672 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Mon, 10 Feb 2020 12:51:37 -0800 Subject: [PATCH 007/185] [SPARK-30759][SQL] Initialize cache for foldable patterns in StringRegexExpression ### What changes were proposed in this pull request? In the PR, I propose to fix `cache` initialization in `StringRegexExpression` by changing `case Literal(value: String, StringType)` to `case p: Expression if p.foldable` ### Why are the changes needed? Actually, the case doesn't work at all because of: 1. Literals value has type `UTF8String` 2. It doesn't work for foldable expressions like in the example: ```sql SELECT '%SystemDrive%\Users\John' _FUNC_ '%SystemDrive%\\Users.*'; ``` Screen Shot 2020-02-08 at 22 45 50 ### Does this PR introduce any user-facing change? No ### How was this patch tested? By the `check outputs of expression examples` test from `SQLQuerySuite`. Closes #27502 from MaxGekk/str-regexp-foldable-pattern. 
Authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/expressions/regexpExpressions.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f84c476ea5807..f8d328bf601e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -41,9 +41,10 @@ trait StringRegexExpression extends Expression override def dataType: DataType = BooleanType - // try cache the pattern for Literal + // try cache foldable pattern private lazy val cache: Pattern = pattern match { - case Literal(value: String, StringType) => compile(value) + case p: Expression if p.foldable => + compile(p.eval().asInstanceOf[UTF8String].toString) case _ => null } From a6b91d2bf727e175d0e175295001db85647539b1 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 10 Feb 2020 22:16:25 +0100 Subject: [PATCH 008/185] [SPARK-30556][SQL][FOLLOWUP] Reset the status changed in SQLExecution.withThreadLocalCaptured ### What changes were proposed in this pull request? Follow up for #27267, reset the status changed in SQLExecution.withThreadLocalCaptured. ### Why are the changes needed? For code safety. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing UT. Closes #27516 from xuanyuanking/SPARK-30556-follow. Authored-by: Yuanjian Li Signed-off-by: herman --- .../apache/spark/sql/execution/SQLExecution.scala | 12 +++++++++++- .../sql/internal/ExecutorSideSQLConfSuite.scala | 10 ++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 995d94ef5eac7..9f177819f6ea7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -177,9 +177,19 @@ object SQLExecution { val sc = sparkSession.sparkContext val localProps = Utils.cloneProperties(sc.getLocalProperties) Future { + val originalSession = SparkSession.getActiveSession + val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) sc.setLocalProperties(localProps) - body + val res = body + // reset active session and local props. 
+ sc.setLocalProperties(originalLocalProps) + if (originalSession.nonEmpty) { + SparkSession.setActiveSession(originalSession.get) + } else { + SparkSession.clearActiveSession() + } + res }(exec) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 0cc658c499615..46d0c64592a00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.internal +import java.util.UUID + import org.scalatest.Assertions._ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext} @@ -144,16 +146,16 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { } // set local configuration and assert - val confValue1 = "e" + val confValue1 = UUID.randomUUID().toString() createDataframe(confKey, confValue1).createOrReplaceTempView("m") spark.sparkContext.setLocalProperty(confKey, confValue1) - assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect.size == 1) + assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM m)").collect().length == 1) // change the conf value and assert again - val confValue2 = "f" + val confValue2 = UUID.randomUUID().toString() createDataframe(confKey, confValue2).createOrReplaceTempView("n") spark.sparkContext.setLocalProperty(confKey, confValue2) - assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().size == 1) + assert(sql("SELECT * FROM l WHERE EXISTS (SELECT * FROM n)").collect().length == 1) } } } From e2ebca733ce4366349a5a25fe94a8e31b67d410e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 10 Feb 2020 14:26:14 -0800 Subject: [PATCH 009/185] [SPARK-30779][SS] Fix some API issues found when reviewing Structured Streaming API docs ### What changes were proposed in this pull request? - Fix the scope of `Logging.initializeForcefully` so that it doesn't appear in subclasses' public methods. Right now, `sc.initializeForcefully(false, false)` is allowed to called. - Don't show classes under `org.apache.spark.internal` package in API docs. - Add missing `since` annotation. - Fix the scope of `ArrowUtils` to remove it from the API docs. ### Why are the changes needed? Avoid leaking APIs unintentionally in Spark 3.0.0. ### Does this PR introduce any user-facing change? No. All these changes are to avoid leaking APIs unintentionally in Spark 3.0.0. ### How was this patch tested? Manually generated the API docs and verified the above issues have been fixed. Closes #27528 from zsxwing/audit-ss-apis. 
Authored-by: Shixiong Zhu Signed-off-by: Xiao Li --- core/src/main/scala/org/apache/spark/internal/Logging.scala | 2 +- project/SparkBuild.scala | 1 + .../sql/connector/read/streaming/ContinuousPartitionReader.java | 2 ++ .../read/streaming/ContinuousPartitionReaderFactory.java | 2 ++ .../spark/sql/connector/read/streaming/ContinuousStream.java | 2 ++ .../spark/sql/connector/read/streaming/MicroBatchStream.java | 2 ++ .../org/apache/spark/sql/connector/read/streaming/Offset.java | 2 ++ .../spark/sql/connector/read/streaming/PartitionOffset.java | 2 ++ .../apache/spark/sql/connector/read/streaming/ReadLimit.java | 1 + .../spark/sql/connector/read/streaming/SparkDataStream.java | 2 ++ .../connector/write/streaming/StreamingDataWriterFactory.java | 2 ++ .../spark/sql/connector/write/streaming/StreamingWrite.java | 2 ++ .../src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala | 2 +- 13 files changed, 22 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index 2e4846bec2db4..0c1d9635b6535 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -117,7 +117,7 @@ trait Logging { } // For testing - def initializeForcefully(isInterpreter: Boolean, silent: Boolean): Unit = { + private[spark] def initializeForcefully(isInterpreter: Boolean, silent: Boolean): Unit = { initializeLogging(isInterpreter, silent) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 707c31d2248eb..9d0af3aa8c1b6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -819,6 +819,7 @@ object Unidoc { .map(_.filterNot(_.getName.contains("$"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/deploy"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/examples"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/internal"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/memory"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/network"))) .map(_.filterNot(f => diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java index 8bd5273bb7d8e..c2ad9ec244a0d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReader.java @@ -22,6 +22,8 @@ /** * A variation on {@link PartitionReader} for use with continuous streaming processing. + * + * @since 3.0.0 */ @Evolving public interface ContinuousPartitionReader extends PartitionReader { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java index 962864da4aad8..385c6f655440f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousPartitionReaderFactory.java @@ -27,6 +27,8 @@ /** * A variation on {@link PartitionReaderFactory} that returns {@link ContinuousPartitionReader} * instead of {@link PartitionReader}. 
It's used for continuous streaming processing. + * + * @since 3.0.0 */ @Evolving public interface ContinuousPartitionReaderFactory extends PartitionReaderFactory { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java index ee01a2553ae7a..a84578fe461a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ContinuousStream.java @@ -23,6 +23,8 @@ /** * A {@link SparkDataStream} for streaming queries with continuous mode. + * + * @since 3.0.0 */ @Evolving public interface ContinuousStream extends SparkDataStream { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java index ceab0f75734d3..40ecbf0578ee5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/MicroBatchStream.java @@ -25,6 +25,8 @@ /** * A {@link SparkDataStream} for streaming queries with micro-batch mode. + * + * @since 3.0.0 */ @Evolving public interface MicroBatchStream extends SparkDataStream { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java index 400de2a659746..efb8ebb684f06 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/Offset.java @@ -25,6 +25,8 @@ * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. + * + * @since 3.0.0 */ @Evolving public abstract class Offset { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java index 35ad3bbde5cbf..faee230467bea 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/PartitionOffset.java @@ -26,6 +26,8 @@ * provide a method to merge these into a global Offset. * * These offsets must be serializable. 
+ * + * @since 3.0.0 */ @Evolving public interface PartitionOffset extends Serializable { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java index 121ed1ad116f9..36f6e05e365d9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/ReadLimit.java @@ -27,6 +27,7 @@ * @see SupportsAdmissionControl#latestOffset(Offset, ReadLimit) * @see ReadAllAvailable * @see ReadMaxRows + * @since 3.0.0 */ @Evolving public interface ReadLimit { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java index 1ba0c25ef4466..95703e255ea4e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/streaming/SparkDataStream.java @@ -25,6 +25,8 @@ * * Data sources should implement concrete data stream interfaces: * {@link MicroBatchStream} and {@link ContinuousStream}. + * + * @since 3.0.0 */ @Evolving public interface SparkDataStream { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java index 9946867e8ea65..0923d07e7e5a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingDataWriterFactory.java @@ -33,6 +33,8 @@ * Note that, the writer factory will be serialized and sent to executors, then the data writer * will be created on executors and do the actual writing. So this interface must be * serializable and {@link DataWriter} doesn't need to be. + * + * @since 3.0.0 */ @Evolving public interface StreamingDataWriterFactory extends Serializable { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java index 4f930e1c158e5..e3dec3b2ff55e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/streaming/StreamingWrite.java @@ -40,6 +40,8 @@ * do it manually in their Spark applications if they want to retry. * * Please refer to the documentation of commit/abort methods for detailed specifications. 
+ * + * @since 3.0.0 */ @Evolving public interface StreamingWrite { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 2da0d1a51cb29..003ce850c926e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -27,7 +27,7 @@ import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -object ArrowUtils { +private[sql] object ArrowUtils { val rootAllocator = new RootAllocator(Long.MaxValue) From 07a9885f2792be1353f4a923d649e90bc431cb38 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 11 Feb 2020 10:03:01 +0900 Subject: [PATCH 010/185] [SPARK-30777][PYTHON][TESTS] Fix test failures for Pandas >= 1.0.0 ### What changes were proposed in this pull request? Fix PySpark test failures for using Pandas >= 1.0.0. ### Why are the changes needed? Pandas 1.0.0 has recently been released and has API changes that result in PySpark test failures, this PR fixes the broken tests. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Manually tested with Pandas 1.0.1 and PyArrow 0.16.0 Closes #27529 from BryanCutler/pandas-fix-tests-1.0-SPARK-30777. Authored-by: Bryan Cutler Signed-off-by: HyukjinKwon --- python/pyspark/sql/tests/test_arrow.py | 4 ++-- python/pyspark/sql/tests/test_pandas_grouped_map.py | 6 +++--- python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 98f44dfd29da5..004c79f290213 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -297,9 +297,9 @@ def test_createDataFrame_does_not_modify_input(self): # Some series get converted for Spark to consume, this makes sure input is unchanged pdf = self.create_pandas_data_frame() # Use a nanosecond value to make sure it is not truncated - pdf.ix[0, '8_timestamp_t'] = pd.Timestamp(1) + pdf.iloc[0, 7] = pd.Timestamp(1) # Integers with nulls will get NaNs filled with 0 and will be casted - pdf.ix[1, '2_int_t'] = None + pdf.iloc[1, 1] = None pdf_copy = pdf.copy(deep=True) self.spark.createDataFrame(pdf, schema=self.schema) self.assertTrue(pdf.equals(pdf_copy)) diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py index 51dd07fd7d70c..ff53a0c6f2cf2 100644 --- a/python/pyspark/sql/tests/test_pandas_grouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py @@ -390,11 +390,11 @@ def rename_pdf(pdf, names): # Function returns a pdf with required column names, but order could be arbitrary using dict def change_col_order(pdf): # Constructing a DataFrame from a dict should result in the same order, - # but use from_items to ensure the pdf column order is different than schema - return pd.DataFrame.from_items([ + # but use OrderedDict to ensure the pdf column order is different than schema + return pd.DataFrame.from_dict(OrderedDict([ ('id', pdf.id), ('u', pdf.v * 2), - ('v', pdf.v)]) + ('v', pdf.v)])) ordered_udf = pandas_udf( change_col_order, diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 974ad560daebf..21679785a769e 100644 --- 
a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -357,7 +357,7 @@ def test_complex_expressions(self): plus_one(sum_udf(col('v1'))), sum_udf(plus_one(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) expected1 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) @@ -368,7 +368,7 @@ def test_complex_expressions(self): plus_one(sum(col('v1'))), sum(plus_one(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) # Test complex expressions with sql expression, scala pandas UDF and # group aggregate pandas UDF @@ -381,7 +381,7 @@ def test_complex_expressions(self): plus_two(sum_udf(col('v1'))), sum_udf(plus_two(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) expected2 = (df.withColumn('v1', df.v + 1) .withColumn('v2', df.v + 2) @@ -392,7 +392,7 @@ def test_complex_expressions(self): plus_two(sum(col('v1'))), sum(plus_two(col('v2')))) .sort(['id', '(v % 2)']) - .toPandas().sort_index(by=['id', '(v % 2)'])) + .toPandas().sort_values(by=['id', '(v % 2)'])) # Test sequential groupby aggregate result3 = (df.groupby('id') From 2bc765a831d7f15c7971d41c36cfbec1fd898dfd Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 11 Feb 2020 15:50:03 +0900 Subject: [PATCH 011/185] [SPARK-30756][SQL] Fix `ThriftServerWithSparkContextSuite` on spark-branch-3.0-test-sbt-hadoop-2.7-hive-2.3 ### What changes were proposed in this pull request? This PR tries #26710 (comment) way to fix the test. ### Why are the changes needed? To make the tests pass. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Jenkins will test first, and then `on spark-branch-3.0-test-sbt-hadoop-2.7-hive-2.3` will test it out. Closes #27513 from HyukjinKwon/test-SPARK-30756. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon (cherry picked from commit 8efe367a4ee862b8a85aee8881b0335b34cbba70) Signed-off-by: HyukjinKwon --- project/SparkBuild.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 9d0af3aa8c1b6..1c5c36ea8eae2 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -478,7 +478,8 @@ object SparkParallelTestGrouping { "org.apache.spark.sql.hive.thriftserver.ThriftServerQueryTestSuite", "org.apache.spark.sql.hive.thriftserver.SparkSQLEnvSuite", "org.apache.spark.sql.hive.thriftserver.ui.ThriftServerPageSuite", - "org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2ListenerSuite" + "org.apache.spark.sql.hive.thriftserver.ui.HiveThriftServer2ListenerSuite", + "org.apache.spark.sql.hive.thriftserver.ThriftServerWithSparkContextSuite" ) private val DEFAULT_TEST_GROUP = "default_test_group" From 0045be766b949dff23ed72bd559568f17f645ffe Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 11 Feb 2020 17:22:08 +0900 Subject: [PATCH 012/185] [SPARK-29462][SQL] The data type of "array()" should be array ### What changes were proposed in this pull request? This brings https://github.com/apache/spark/pull/26324 back. It was reverted basically because, firstly Hive compatibility, and the lack of investigations in other DBMSes and ANSI. - In case of PostgreSQL seems coercing NULL literal to TEXT type. - Presto seems coercing `array() + array(1)` -> array of int. 
- Hive seems `array() + array(1)` -> array of strings Given that, the design choices have been differently made for some reasons. If we pick one of both, seems coercing to array of int makes much more sense. Another investigation was made offline internally. Seems ANSI SQL 2011, section 6.5 "" states: > If ES is specified, then let ET be the element type determined by the context in which ES appears. The declared type DT of ES is Case: > > a) If ES simply contains ARRAY, then ET ARRAY[0]. > > b) If ES simply contains MULTISET, then ET MULTISET. > > ES is effectively replaced by CAST ( ES AS DT ) From reading other related context, doing it to `NullType`. Given the investigation made, choosing to `null` seems correct, and we have a reference Presto now. Therefore, this PR proposes to bring it back. ### Why are the changes needed? When empty array is created, it should be declared as array. ### Does this PR introduce any user-facing change? Yes, `array()` creates `array`. Now `array(1) + array()` can correctly create `array(1)` instead of `array("1")`. ### How was this patch tested? Tested manually Closes #27521 from HyukjinKwon/SPARK-29462. Lead-authored-by: HyukjinKwon Co-authored-by: Aman Omer Signed-off-by: HyukjinKwon --- docs/sql-migration-guide.md | 2 ++ .../expressions/complexTypeCreator.scala | 11 ++++++++++- .../org/apache/spark/sql/internal/SQLConf.scala | 9 +++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 +++++++++++++---- 4 files changed, 34 insertions(+), 5 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 26eb5838892b4..f98fab5b4c56b 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -215,6 +215,8 @@ license: | For example `SELECT timestamp 'tomorrow';`. - Since Spark 3.0, the `size` function returns `NULL` for the `NULL` input. In Spark version 2.4 and earlier, this function gives `-1` for the same input. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.sizeOfNull` to `true`. + + - Since Spark 3.0, when the `array` function is called without any parameters, it returns an empty array of `NullType`. In Spark version 2.4 and earlier, it returns an empty array of string type. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.arrayDefaultToStringType.enabled` to `true`. - Since Spark 3.0, the interval literal syntax does not allow multiple from-to units anymore. For example, `SELECT INTERVAL '1-1' YEAR TO MONTH '2-2' YEAR TO MONTH'` throws parser exception. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 9ce87a4922c01..7335e305bfe55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -44,10 +45,18 @@ case class CreateArray(children: Seq[Expression]) extends Expression { TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") } + private val defaultElementType: DataType = { + if (SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING)) { + StringType + } else { + NullType + } + } + override def dataType: ArrayType = { ArrayType( TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(children.map(_.dataType)) - .getOrElse(StringType), + .getOrElse(defaultElementType), containsNull = children.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 64c613611c861..d86f8693e0655 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2007,6 +2007,15 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_ARRAY_DEFAULT_TO_STRING = + buildConf("spark.sql.legacy.arrayDefaultToStringType.enabled") + .internal() + .doc("When set to true, it returns an empty array of string type when the `array` " + + "function is called without any parameters. 
Otherwise, it returns an empty " + + "array of `NullType`") + .booleanConf + .createWithDefault(false) + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") .internal() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 7fce03658fc16..9e9d8c3e9a7c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3499,12 +3499,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } - test("SPARK-21281 use string types by default if array and map have no argument") { + test("SPARK-21281 use string types by default if map have no argument") { val ds = spark.range(1) var expectedSchema = new StructType() - .add("x", ArrayType(StringType, containsNull = false), nullable = false) - assert(ds.select(array().as("x")).schema == expectedSchema) - expectedSchema = new StructType() .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) assert(ds.select(map().as("x")).schema == expectedSchema) } @@ -3577,6 +3574,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { }.getMessage assert(nonFoldableError.contains("The 'escape' parameter must be a string literal")) } + + test("SPARK-29462: Empty array of NullType for array function with no arguments") { + Seq((true, StringType), (false, NullType)).foreach { + case (arrayDefaultToString, expectedType) => + withSQLConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING.key -> arrayDefaultToString.toString) { + val schema = spark.range(1).select(array()).schema + assert(schema.nonEmpty && schema.head.dataType.isInstanceOf[ArrayType]) + val actualType = schema.head.dataType.asInstanceOf[ArrayType].elementType + assert(actualType === expectedType) + } + } + } } object DataFrameFunctionsSuite { From b20754d9ee033091e2ef4d5bfa2576f946c9df50 Mon Sep 17 00:00:00 2001 From: root1 Date: Tue, 11 Feb 2020 20:42:02 +0800 Subject: [PATCH 013/185] [SPARK-27545][SQL][DOC] Update the Documentation for CACHE TABLE and UNCACHE TABLE ### What changes were proposed in this pull request? Document updated for `CACHE TABLE` & `UNCACHE TABLE` ### Why are the changes needed? Cache table creates a temp view while caching data using `CACHE TABLE name AS query`. `UNCACHE TABLE` does not remove this temp view. These things were not mentioned in the existing doc for `CACHE TABLE` & `UNCACHE TABLE`. ### Does this PR introduce any user-facing change? Document updated for `CACHE TABLE` & `UNCACHE TABLE` command. ### How was this patch tested? Manually Closes #27090 from iRakson/SPARK-27545. Lead-authored-by: root1 Co-authored-by: iRakson Signed-off-by: Wenchen Fan --- docs/sql-ref-syntax-aux-cache-cache-table.md | 3 ++- docs/sql-ref-syntax-aux-cache-uncache-table.md | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/sql-ref-syntax-aux-cache-cache-table.md b/docs/sql-ref-syntax-aux-cache-cache-table.md index ed6ef973466dd..20ade1961ab0b 100644 --- a/docs/sql-ref-syntax-aux-cache-cache-table.md +++ b/docs/sql-ref-syntax-aux-cache-cache-table.md @@ -20,7 +20,8 @@ license: | --- ### Description -`CACHE TABLE` statement caches contents of a table or output of a query with the given storage level. This reduces scanning of the original files in future queries. 
+`CACHE TABLE` statement caches contents of a table or output of a query with the given storage level. If a query is cached, then a temp view will be created for this query. +This reduces scanning of the original files in future queries. ### Syntax {% highlight sql %} diff --git a/docs/sql-ref-syntax-aux-cache-uncache-table.md b/docs/sql-ref-syntax-aux-cache-uncache-table.md index e0581d0d213df..69e21c258a333 100644 --- a/docs/sql-ref-syntax-aux-cache-uncache-table.md +++ b/docs/sql-ref-syntax-aux-cache-uncache-table.md @@ -21,11 +21,13 @@ license: | ### Description `UNCACHE TABLE` removes the entries and associated data from the in-memory and/or on-disk cache for a given table or view. The -underlying entries should already have been brought to cache by previous `CACHE TABLE` operation. `UNCACHE TABLE` on a non-existent table throws Exception if `IF EXISTS` is not specified. +underlying entries should already have been brought to cache by previous `CACHE TABLE` operation. `UNCACHE TABLE` on a non-existent table throws an exception if `IF EXISTS` is not specified. + ### Syntax {% highlight sql %} UNCACHE TABLE [ IF EXISTS ] table_identifier {% endhighlight %} + ### Parameters

table_identifier
@@ -37,10 +39,12 @@ UNCACHE TABLE [ IF EXISTS ] table_identifier
+ ### Examples {% highlight sql %} UNCACHE TABLE t1; {% endhighlight %} + ### Related Statements * [CACHE TABLE](sql-ref-syntax-aux-cache-cache-table.html) * [CLEAR CACHE](sql-ref-syntax-aux-cache-clear-cache.html) From f1d0dce4848a53831268c80bf7e1e0f47a1f7612 Mon Sep 17 00:00:00 2001 From: fuwhu Date: Tue, 11 Feb 2020 22:16:44 +0800 Subject: [PATCH 014/185] [MINOR][DOC] Add class document for PruneFileSourcePartitions and PruneHiveTablePartitions ### What changes were proposed in this pull request? Add class document for PruneFileSourcePartitions and PruneHiveTablePartitions. ### Why are the changes needed? To describe these two classes. ### Does this PR introduce any user-facing change? no ### How was this patch tested? no Closes #27535 from fuwhu/SPARK-15616-FOLLOW-UP. Authored-by: fuwhu Signed-off-by: Wenchen Fan --- .../datasources/PruneFileSourcePartitions.scala | 13 +++++++++++++ .../hive/execution/PruneHiveTablePartitions.scala | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 1ea19c187e51a..a7129fb14d1a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -26,6 +26,19 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} import org.apache.spark.sql.types.StructType +/** + * Prune the partitions of file source based table using partition filters. Currently, this rule + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] + * with [[FileScan]]. + * + * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding + * statistics will be updated. And the partition filters will be kept in the filters of returned + * logical plan. + * + * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to + * its underlying [[FileScan]]. And the partition filters will be removed in the filters of + * returned logical plan. + */ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { private def getPartitionKeyFiltersAndDataFilters( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index a0349f627d107..da6e4c52cf3a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -30,6 +30,14 @@ import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf /** + * Prune hive table partitions using partition filters on [[HiveTableRelation]]. The pruned + * partitions will be kept in [[HiveTableRelation.prunedPartitions]], and the statistics of + * the hive table relation will be updated based on pruned partitions. + * + * This rule is executed in optimization phase, so the statistics can be updated before physical + * planning, which is useful for some spark strategy, eg. + * [[org.apache.spark.sql.execution.SparkStrategies.JoinSelection]]. 
+ * * TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source. */ private[sql] class PruneHiveTablePartitions(session: SparkSession) From dc66d57e981ac5108e097d4298fa467f0843ffcf Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 11 Feb 2020 09:07:40 -0600 Subject: [PATCH 015/185] [SPARK-30754][SQL] Reuse results of floorDiv in calculations of floorMod in DateTimeUtils ### What changes were proposed in this pull request? In the case of back-to-back calculation of `floorDiv` and `floorMod` with the same arguments, the result of `foorDiv` can be reused in calculation of `floorMod`. The `floorMod` method is defined as the following in Java standard library: ```java public static int floorMod(int x, int y) { int r = x - floorDiv(x, y) * y; return r; } ``` If `floorDiv(x, y)` has been already calculated, it can be reused in `x - floorDiv(x, y) * y`. I propose to modify 2 places in `DateTimeUtils`: 1. `microsToInstant` which is widely used in many date-time functions. `Math.floorMod(us, MICROS_PER_SECOND)` is just replaced by its definition from Java Math library. 2. `truncDate`: `Math.floorMod(oldYear, divider) == 0` is replaced by `Math.floorDiv(oldYear, divider) * divider == oldYear` where `floorDiv(...) * divider` is pre-calculated. ### Why are the changes needed? This reduces the number of arithmetic operations, and can slightly improve performance of date-time functions. ### Does this PR introduce any user-facing change? No ### How was this patch tested? By existing test suites `DateTimeUtilsSuite`, `DateFunctionsSuite` and `DateExpressionsSuite`. Closes #27491 from MaxGekk/opt-microsToInstant. Authored-by: Maxim Gekk Signed-off-by: Sean Owen --- .../spark/sql/catalyst/util/DateTimeUtils.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 8eb560944d4cb..eeae0674166bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -337,7 +337,9 @@ object DateTimeUtils { def microsToInstant(us: Long): Instant = { val secs = Math.floorDiv(us, MICROS_PER_SECOND) - val mos = Math.floorMod(us, MICROS_PER_SECOND) + // Unfolded Math.floorMod(us, MICROS_PER_SECOND) to reuse the result of + // the above calculation of `secs` via `floorDiv`. + val mos = us - secs * MICROS_PER_SECOND Instant.ofEpochSecond(secs, mos * NANOS_PER_MICROS) } @@ -691,11 +693,11 @@ object DateTimeUtils { def truncDate(d: SQLDate, level: Int): SQLDate = { def truncToYearLevel(divider: Int, adjust: Int): SQLDate = { val oldYear = getYear(d) - var newYear = Math.floorDiv(oldYear, divider) - if (adjust > 0 && Math.floorMod(oldYear, divider) == 0) { - newYear -= 1 + var newYear = Math.floorDiv(oldYear, divider) * divider + if (adjust > 0 && newYear == oldYear) { + newYear -= divider } - newYear = newYear * divider + adjust + newYear += adjust localDateToDays(LocalDate.of(newYear, 1, 1)) } level match { From ea626b6acf0de0ff3b0678372f30ba6f84ae2b09 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 12 Feb 2020 00:12:45 +0800 Subject: [PATCH 016/185] [SPARK-30783] Exclude hive-service-rpc ### What changes were proposed in this pull request? Exclude hive-service-rpc from build. ### Why are the changes needed? 
hive-service-rpc 2.3.6 and spark sql's thrift server module have duplicate classes. Leaving hive-service-rpc 2.3.6 in the class path means that spark can pick up classes defined in hive instead of its thrift server module, which can cause hard to debug runtime errors due to class loading order and compilation errors for applications depend on spark. If you compare hive-service-rpc 2.3.6's jar (https://search.maven.org/remotecontent?filepath=org/apache/hive/hive-service-rpc/2.3.6/hive-service-rpc-2.3.6.jar) and spark thrift server's jar (e.g. https://repository.apache.org/content/groups/snapshots/org/apache/spark/spark-hive-thriftserver_2.12/3.0.0-SNAPSHOT/spark-hive-thriftserver_2.12-3.0.0-20200207.021914-364.jar), you will see that all of classes provided by hive-service-rpc-2.3.6.jar are covered by spark thrift server's jar. https://issues.apache.org/jira/browse/SPARK-30783 has output of jar tf for both jars. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing tests. Closes #27533 from yhuai/SPARK-30783. Authored-by: Yin Huai Signed-off-by: Wenchen Fan --- dev/deps/spark-deps-hadoop-2.7-hive-2.3 | 1 - dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 1 - pom.xml | 20 ++++++++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 42bdf112efccb..c50cf96dc9065 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -87,7 +87,6 @@ hive-jdbc/2.3.6//hive-jdbc-2.3.6.jar hive-llap-common/2.3.6//hive-llap-common-2.3.6.jar hive-metastore/2.3.6//hive-metastore-2.3.6.jar hive-serde/2.3.6//hive-serde-2.3.6.jar -hive-service-rpc/2.3.6//hive-service-rpc-2.3.6.jar hive-shims-0.23/2.3.6//hive-shims-0.23-2.3.6.jar hive-shims-common/2.3.6//hive-shims-common-2.3.6.jar hive-shims-scheduler/2.3.6//hive-shims-scheduler-2.3.6.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index 6006fa4b43f42..c37ce7fab36f6 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -86,7 +86,6 @@ hive-jdbc/2.3.6//hive-jdbc-2.3.6.jar hive-llap-common/2.3.6//hive-llap-common-2.3.6.jar hive-metastore/2.3.6//hive-metastore-2.3.6.jar hive-serde/2.3.6//hive-serde-2.3.6.jar -hive-service-rpc/2.3.6//hive-service-rpc-2.3.6.jar hive-shims-0.23/2.3.6//hive-shims-0.23-2.3.6.jar hive-shims-common/2.3.6//hive-shims-common-2.3.6.jar hive-shims-scheduler/2.3.6//hive-shims-scheduler-2.3.6.jar diff --git a/pom.xml b/pom.xml index a8d6ac932bac2..925fa28a291a4 100644 --- a/pom.xml +++ b/pom.xml @@ -1452,6 +1452,11 @@ ${hive.group} hive-service + + + ${hive.group} + hive-service-rpc + ${hive.group} hive-shims @@ -1508,6 +1513,11 @@ ${hive.group} hive-service + + + ${hive.group} + hive-service-rpc + ${hive.group} hive-shims @@ -1761,6 +1771,11 @@ ${hive.group} hive-service + + + ${hive.group} + hive-service-rpc + ${hive.group} hive-shims @@ -1911,6 +1926,11 @@ groovy-all + + + ${hive.group} + hive-service-rpc + org.apache.parquet From 99bd59fe29a87bb70485db536b0ae676e7a9d42e Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Tue, 11 Feb 2020 09:55:02 -0800 Subject: [PATCH 017/185] [SPARK-29462][SQL][DOCS] Add some more context and details in 'spark.sql.defaultUrlStreamHandlerFactory.enabled' documentation ### What changes were proposed in this pull request? This PR adds some more information and context to `spark.sql.defaultUrlStreamHandlerFactory.enabled`. 
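As a rough usage sketch (not taken from this patch): because the option is declared with `buildStaticConf`, it is a static SQL conf and has to be supplied before the `SparkSession` is created, for example via the session builder or `--conf` on `spark-submit`. The chosen value and the session-building pattern below are illustrative assumptions only.

```python
from pyspark.sql import SparkSession

# Static SQL conf: set it before the session (and its shared state) starts.
# Disabling it skips registering Hadoop's FsUrlStreamHandlerFactory, e.g. when
# a different stream protocol handler for http/https must be registered instead.
spark = (
    SparkSession.builder
    .config("spark.sql.defaultUrlStreamHandlerFactory.enabled", "false")
    .getOrCreate()
)
```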
### Why are the changes needed? It is a bit difficult to understand the documentation of `spark.sql.defaultUrlStreamHandlerFactory.enabled`. ### Does this PR introduce any user-facing change? Nope, internal doc only fix. ### How was this patch tested? Nope. I only tested linter. Closes #27541 from HyukjinKwon/SPARK-29462-followup. Authored-by: HyukjinKwon Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/internal/StaticSQLConf.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index 6bc752260a893..563e51ed597b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -172,7 +172,13 @@ object StaticSQLConf { val DEFAULT_URL_STREAM_HANDLER_FACTORY_ENABLED = buildStaticConf("spark.sql.defaultUrlStreamHandlerFactory.enabled") - .doc("When true, set FsUrlStreamHandlerFactory to support ADD JAR against HDFS locations") + .doc( + "When true, register Hadoop's FsUrlStreamHandlerFactory to support " + + "ADD JAR against HDFS locations. " + + "It should be disabled when a different stream protocol handler should be registered " + + "to support a particular protocol type, or if Hadoop's FsUrlStreamHandlerFactory " + + "conflicts with other protocol types such as `http` or `https`. See also SPARK-25694 " + + "and HADOOP-14598.") .internal() .booleanConf .createWithDefault(true) From 45db48e2d29359591a4ebc3db4625dd2158e446e Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Tue, 11 Feb 2020 10:15:34 -0800 Subject: [PATCH 018/185] Revert "[SPARK-30625][SQL] Support `escape` as third parameter of the `like` function ### What changes were proposed in this pull request? In the PR, I propose to revert the commit 8aebc80e0e67bcb1aa300b8c8b1a209159237632. ### Why are the changes needed? See the concerns https://github.com/apache/spark/pull/27355#issuecomment-584344438 ### Does this PR introduce any user-facing change? No ### How was this patch tested? By existing test suites. Closes #27531 from MaxGekk/revert-like-3-args. 
Authored-by: Maxim Gekk Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/dsl/package.scala | 2 +- .../expressions/regexpExpressions.scala | 85 ++++++------------- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 15 ---- 4 files changed, 31 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 40998080bc4e3..b4a8bafe22dfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -99,7 +99,7 @@ package object dsl { } def like(other: Expression, escapeChar: Char = '\\'): Expression = - Like(expr, other, Literal(escapeChar.toString)) + Like(expr, other, escapeChar) def rlike(other: Expression): Expression = RLike(expr, other) def contains(other: Expression): Expression = Contains(expr, other) def startsWith(other: Expression): Expression = StartsWith(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f8d328bf601e4..e5ee0edfcf79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -22,7 +22,6 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.text.StringEscapeUtils -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} @@ -30,19 +29,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends Expression +abstract class StringRegexExpression extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - def str: Expression - def pattern: Expression - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache foldable pattern - private lazy val cache: Pattern = pattern match { + private lazy val cache: Pattern = right match { case p: Expression if p.foldable => compile(p.eval().asInstanceOf[UTF8String].toString) case _ => null @@ -55,9 +52,10 @@ trait StringRegexExpression extends Expression Pattern.compile(escape(str)) } - def nullSafeMatch(input1: Any, input2: Any): Any = { - val s = input2.asInstanceOf[UTF8String].toString - val regex = if (cache == null) compile(s) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { @@ -65,7 +63,7 @@ trait StringRegexExpression extends Expression } } - override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } // scalastyle:off line.contains.tab @@ -110,65 +108,46 @@ trait StringRegexExpression extends Expression true > SELECT 
'%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/'; true - > SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_'); - true """, note = """ Use RLIKE to match with standard regular expressions. """, since = "1.0.0") // scalastyle:on line.contains.tab -case class Like(str: Expression, pattern: Expression, escape: Expression) - extends TernaryExpression with StringRegexExpression { - - def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\")) - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = Seq(str, pattern, escape) +case class Like(left: Expression, right: Expression, escapeChar: Char) + extends StringRegexExpression { - private lazy val escapeChar: Char = if (escape.foldable) { - escape.eval() match { - case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0) - case s => throw new AnalysisException( - s"The 'escape' parameter must be a string literal of one char but it is $s.") - } - } else { - throw new AnalysisException("The 'escape' parameter must be a string literal.") - } + def this(left: Expression, right: Expression) = this(left, right, '\\') override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = escapeChar match { - case '\\' => s"$str LIKE $pattern" - case c => s"$str LIKE $pattern ESCAPE '$c'" - } - - protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - nullSafeMatch(input1, input2) + case '\\' => s"$left LIKE $right" + case c => s"$left LIKE $right ESCAPE '$c'" } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - if (pattern.foldable) { - val patternVal = pattern.eval() - if (patternVal != null) { + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { val regexStr = - StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString())) - val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern", + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + val pattern = ctx.addMutableState(patternClass, "patternLike", v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = str.genCode(ctx) + val eval = left.genCode(ctx) ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches(); + ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } """) } else { @@ -178,8 +157,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """) } } else { - val patternStr = ctx.freshName("patternStr") - val compiledPattern = ctx.freshName("compiledPattern") + val pattern = ctx.freshName("pattern") + val rightStr = ctx.freshName("rightStr") // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. 
@@ -188,12 +167,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) } else { escapeChar } - nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $patternStr = $eval2.toString(); - $patternClass $compiledPattern = $patternClass.compile( - $escapeFunc($patternStr, '$newEscapeChar')); - ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches(); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile( + $escapeFunc($rightStr, '$newEscapeChar')); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) } @@ -232,20 +211,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """, since = "1.0.0") // scalastyle:on line.contains.tab -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def str: Expression = left - override def pattern: Expression = right +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6fc65e14868e0..62e568587fcc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging throw new ParseException("Invalid escape string." 
+ "Escape string must contains only one character.", ctx) } - str + str.charAt(0) }.getOrElse('\\') - invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar))) + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) case SqlBaseParser.RLIKE => invertIfNotDefined(RLike(e, expression(ctx.pattern))) case SqlBaseParser.NULL if ctx.NOT != null => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9e9d8c3e9a7c5..6012678341ccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3560,21 +3560,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(1))) } - test("the like function with the escape parameter") { - val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape") - checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true)) - - val longEscapeError = intercept[AnalysisException] { - df.selectExpr("like(str, pattern, '@%')").collect() - }.getMessage - assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char")) - - val nonFoldableError = intercept[AnalysisException] { - df.selectExpr("like(str, pattern, escape)").collect() - }.getMessage - assert(nonFoldableError.contains("The 'escape' parameter must be a string literal")) - } - test("SPARK-29462: Empty array of NullType for array function with no arguments") { Seq((true, StringType), (false, NullType)).foreach { case (arrayDefaultToString, expectedType) => From b25359cca3190f6a34dce3c3e49c4d2a80e88bdc Mon Sep 17 00:00:00 2001 From: herman Date: Wed, 12 Feb 2020 10:48:29 +0900 Subject: [PATCH 019/185] [SPARK-30780][SQL] Empty LocalTableScan should use RDD without partitions ### What changes were proposed in this pull request? This is a small follow-up for https://github.com/apache/spark/pull/27400. This PR makes an empty `LocalTableScanExec` return an `RDD` without partitions. ### Why are the changes needed? It is a bit unexpected that the RDD contains partitions if there is not work to do. It also can save a bit of work when this is used in a more complex plan. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Added test to `SparkPlanSuite`. Closes #27530 from hvanhovell/SPARK-30780. 
Authored-by: herman Signed-off-by: HyukjinKwon --- .../spark/sql/execution/LocalTableScanExec.scala | 12 ++++++++---- .../scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../apache/spark/sql/execution/SparkPlanSuite.scala | 4 ++++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala index 1b5115f2e29a3..b452213cd6cc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala @@ -45,10 +45,14 @@ case class LocalTableScanExec( } } - private lazy val numParallelism: Int = math.min(math.max(unsafeRows.length, 1), - sqlContext.sparkContext.defaultParallelism) - - private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows, numParallelism) + @transient private lazy val rdd: RDD[InternalRow] = { + if (rows.isEmpty) { + sqlContext.sparkContext.emptyRDD + } else { + val numSlices = math.min(unsafeRows.length, sqlContext.sparkContext.defaultParallelism) + sqlContext.sparkContext.parallelize(unsafeRows, numSlices) + } + } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d2d58a83ded5d..694e576fcded4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -330,7 +330,7 @@ class DataFrameSuite extends QueryTest testData.select("key").coalesce(1).select("key"), testData.select("key").collect().toSeq) - assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 0) } test("convert $\"attribute name\" into unresolved attribute") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala index e3bc414516c04..56fff1107ae39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala @@ -84,4 +84,8 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-30780 empty LocalTableScan should use RDD without partitions") { + assert(LocalTableScanExec(Nil, Nil).execute().getNumPartitions == 0) + } } From aa6a60530e63ab3bb8b117f8738973d1b26a2cc7 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Wed, 12 Feb 2020 10:49:46 +0900 Subject: [PATCH 020/185] [SPARK-30722][PYTHON][DOCS] Update documentation for Pandas UDF with Python type hints ### What changes were proposed in this pull request? This PR targets to document the Pandas UDF redesign with type hints introduced at SPARK-28264. Mostly self-describing; however, there are few things to note for reviewers. 1. This PR replace the existing documentation of pandas UDFs to the newer redesign to promote the Python type hints. I added some words that Spark 3.0 still keeps the compatibility though. 2. This PR proposes to name non-pandas UDFs as "Pandas Function API" 3. SCALAR_ITER become two separate sections to reduce confusion: - `Iterator[pd.Series]` -> `Iterator[pd.Series]` - `Iterator[Tuple[pd.Series, ...]]` -> `Iterator[pd.Series]` 4. 
I removed some examples that look overkill to me. 5. I also removed some information in the doc, that seems duplicating or too much. ### Why are the changes needed? To document new redesign in pandas UDF. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing tests should cover. Closes #27466 from HyukjinKwon/SPARK-30722. Authored-by: HyukjinKwon Signed-off-by: HyukjinKwon --- dev/sparktestsupport/modules.py | 1 - docs/sql-pyspark-pandas-with-arrow.md | 233 +++++++---- examples/src/main/python/sql/arrow.py | 258 ++++++------ python/pyspark/sql/pandas/functions.py | 538 +++++++++++-------------- python/pyspark/sql/pandas/group_ops.py | 99 ++++- python/pyspark/sql/pandas/map_ops.py | 6 +- python/pyspark/sql/udf.py | 16 +- 7 files changed, 609 insertions(+), 542 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 40f2ca288d694..391e4bbe1b1f0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -364,7 +364,6 @@ def __hash__(self): "pyspark.sql.avro.functions", "pyspark.sql.pandas.conversion", "pyspark.sql.pandas.map_ops", - "pyspark.sql.pandas.functions", "pyspark.sql.pandas.group_ops", "pyspark.sql.pandas.types", "pyspark.sql.pandas.serializers", diff --git a/docs/sql-pyspark-pandas-with-arrow.md b/docs/sql-pyspark-pandas-with-arrow.md index 7eb8a74547f70..92a515746b607 100644 --- a/docs/sql-pyspark-pandas-with-arrow.md +++ b/docs/sql-pyspark-pandas-with-arrow.md @@ -35,7 +35,7 @@ working with Arrow-enabled data. If you install PySpark using pip, then PyArrow can be brought in as an extra dependency of the SQL module with the command `pip install pyspark[sql]`. Otherwise, you must ensure that PyArrow -is installed and available on all cluster nodes. The current supported version is 0.12.1. +is installed and available on all cluster nodes. The current supported version is 0.15.1+. You can install using pip or conda from the conda-forge channel. See PyArrow [installation](https://arrow.apache.org/docs/python/install.html) for details. @@ -65,132 +65,216 @@ Spark will fall back to create the DataFrame without Arrow. ## Pandas UDFs (a.k.a. Vectorized UDFs) -Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer data and -Pandas to work with the data. A Pandas UDF is defined using the keyword `pandas_udf` as a decorator -or to wrap the function, no additional configuration is required. Currently, there are two types of -Pandas UDF: Scalar and Grouped Map. +Pandas UDFs are user defined functions that are executed by Spark using +Arrow to transfer data and Pandas to work with the data, which allows vectorized operations. A Pandas +UDF is defined using the `pandas_udf` as a decorator or to wrap the function, and no additional +configuration is required. A Pandas UDF behaves as a regular PySpark function API in general. -### Scalar +Before Spark 3.0, Pandas UDFs used to be defined with `PandasUDFType`. From Spark 3.0 +with Python 3.6+, you can also use [Python type hints](https://www.python.org/dev/peps/pep-0484). +Using Python type hints are preferred and using `PandasUDFType` will be deprecated in +the future release. -Scalar Pandas UDFs are used for vectorizing scalar operations. They can be used with functions such -as `select` and `withColumn`. The Python function should take `pandas.Series` as inputs and return -a `pandas.Series` of the same length. 
Internally, Spark will execute a Pandas UDF by splitting -columns into batches and calling the function for each batch as a subset of the data, then -concatenating the results together. +Note that the type hint should use `pandas.Series` in all cases but there is one variant +that `pandas.DataFrame` should be used for its input or output type hint instead when the input +or output column is of `StructType`. The following example shows a Pandas UDF which takes long +column, string column and struct column, and outputs a struct column. It requires the function to +specify the type hints of `pandas.Series` and `pandas.DataFrame` as below: -The following example shows how to create a scalar Pandas UDF that computes the product of 2 columns. +
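In condensed form, the `ser_to_frame_pandas_udf` example that the include below pulls in (added to `examples/src/main/python/sql/arrow.py` by this patch) defines the UDF as follows; the UDF is then applied to a long column, a string column and a struct column of a DataFrame, and an active `spark` session is assumed as in the other snippets:

{% highlight python %}
import pandas as pd

from pyspark.sql.functions import pandas_udf

# The struct column arrives as a pandas.DataFrame, so its type hint is
# pandas.DataFrame, while the long and string columns keep pandas.Series hints.
@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    s3['col2'] = s1 + s2.str.len()
    return s3
{% endhighlight %}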

+

+
+{% include_example ser_to_frame_pandas_udf python/sql/arrow.py %} +
+
+

+ +In the following sections, it describes the cominations of the supported type hints. For simplicity, +`pandas.DataFrame` variant is omitted. + +### Series to Series + +The type hint can be expressed as `pandas.Series`, ... -> `pandas.Series`. + +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF where the given +function takes one or more `pandas.Series` and outputs one `pandas.Series`. The output of the function should +always be of the same length as the input. Internally, PySpark will execute a Pandas UDF by splitting +columns into batches and calling the function for each batch as a subset of the data, then concatenating +the results together. + +The following example shows how to create this Pandas UDF that computes the product of 2 columns.
-{% include_example scalar_pandas_udf python/sql/arrow.py %} +{% include_example ser_to_ser_pandas_udf python/sql/arrow.py %}
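The `ser_to_ser_pandas_udf` example included above boils down to the following sketch, mirroring the code this patch adds to `examples/src/main/python/sql/arrow.py` (an active `spark` session is assumed):

{% highlight python %}
import pandas as pd

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# The function works on pandas.Series and must return a Series of the same length.
def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["x"]))
df.select(multiply(col("x"), col("x"))).show()
{% endhighlight %}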
-### Scalar Iterator +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) + +### Iterator of Series to Iterator of Series + +The type hint can be expressed as `Iterator[pandas.Series]` -> `Iterator[pandas.Series]`. + +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF where the given +function takes an iterator of `pandas.Series` and outputs an iterator of `pandas.Series`. The output of each +series from the function should always be of the same length as the input. In this case, the created +Pandas UDF requires one input column when the Pandas UDF is called. To use multiple input columns, +a different type hint is required. See Iterator of Multiple Series to Iterator of Series. + +It is useful when the UDF execution requires initializing some states although internally it works +identically as Series to Series case. The pseudocode below illustrates the example. + +{% highlight python %} +@pandas_udf("long") +def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + # Do some expensive initialization with a state + state = very_expensive_initialization() + for x in iterator: + # Use that state for whole iterator. + yield calculate_with_state(x, state) -Scalar iterator (`SCALAR_ITER`) Pandas UDF is the same as scalar Pandas UDF above except that the -underlying Python function takes an iterator of batches as input instead of a single batch and, -instead of returning a single output batch, it yields output batches or returns an iterator of -output batches. -It is useful when the UDF execution requires initializing some states, e.g., loading an machine -learning model file to apply inference to every input batch. +df.select(calculate("value")).show() +{% endhighlight %} -The following example shows how to create scalar iterator Pandas UDFs: +The following example shows how to create this Pandas UDF:
-{% include_example scalar_iter_pandas_udf python/sql/arrow.py %} +{% include_example iter_ser_to_iter_ser_pandas_udf python/sql/arrow.py %}
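For reference, the `iter_ser_to_iter_ser_pandas_udf` example referenced above (added to `examples/src/main/python/sql/arrow.py` by this patch) is essentially:

{% highlight python %}
from typing import Iterator

import pandas as pd

from pyspark.sql.functions import pandas_udf

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["x"]))

# Takes an iterator of batches and yields one output Series per input batch,
# each of the same length as its input.
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in iterator:
        yield x + 1

df.select(plus_one("x")).show()
{% endhighlight %}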
-### Grouped Map -Grouped map Pandas UDFs are used with `groupBy().apply()` which implements the "split-apply-combine" pattern. -Split-apply-combine consists of three steps: -* Split the data into groups by using `DataFrame.groupBy`. -* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The - input data contains all the rows and columns for each group. -* Combine the results into a new `DataFrame`. +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) -To use `groupBy().apply()`, the user needs to define the following: -* A Python function that defines the computation for each group. -* A `StructType` object or a string that defines the schema of the output `DataFrame`. +### Iterator of Multiple Series to Iterator of Series -The column labels of the returned `pandas.DataFrame` must either match the field names in the -defined output schema if specified as strings, or match the field data types by position if not -strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) -on how to label columns when constructing a `pandas.DataFrame`. +The type hint can be expressed as `Iterator[Tuple[pandas.Series, ...]]` -> `Iterator[pandas.Series]`. -Note that all data for a group will be loaded into memory before the function is applied. This can -lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for -[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user -to ensure that the grouped data will fit into the available memory. +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF where the +given function takes an iterator of a tuple of multiple `pandas.Series` and outputs an iterator of `pandas.Series`. +In this case, the created pandas UDF requires multiple input columns as many as the series in the tuple +when the Pandas UDF is called. It works identically as Iterator of Series to Iterator of Series case except the parameter difference. -The following example shows how to use `groupby().apply()` to subtract the mean from each value in the group. +The following example shows how to create this Pandas UDF:
-{% include_example grouped_map_pandas_udf python/sql/arrow.py %} +{% include_example iter_sers_to_iter_ser_pandas_udf python/sql/arrow.py %}
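The included `iter_sers_to_iter_ser_pandas_udf` example (added to `examples/src/main/python/sql/arrow.py` in this patch) is essentially:

{% highlight python %}
from typing import Iterator, Tuple

import pandas as pd

from pyspark.sql.functions import pandas_udf

df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["x"]))

# The UDF is called with as many input columns as there are Series in the tuple.
@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    for a, b in iterator:
        yield a * b

df.select(multiply_two_cols("x", "x")).show()
{% endhighlight %}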
-For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and -[`pyspark.sql.GroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.apply). +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) + +### Series to Scalar -### Grouped Aggregate +The type hint can be expressed as `pandas.Series`, ... -> `Any`. -Grouped aggregate Pandas UDFs are similar to Spark aggregate functions. Grouped aggregate Pandas UDFs are used with `groupBy().agg()` and -[`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). It defines an aggregation from one or more `pandas.Series` -to a scalar value, where each `pandas.Series` represents a column within the group or window. +By using `pandas_udf` with the function having such type hints, it creates a Pandas UDF similar +to PySpark's aggregate functions. The given function takes `pandas.Series` and returns a scalar value. +The return type should be a primitive data type, and the returned scalar can be either a python +primitive type, e.g., `int` or `float` or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. +`Any` should ideally be a specific scalar type accordingly. -Note that this type of UDF does not support partial aggregation and all data for a group or window will be loaded into memory. Also, -only unbounded window is supported with Grouped aggregate Pandas UDFs currently. +This UDF can be also used with `groupBy().agg()` and [`pyspark.sql.Window`](api/python/pyspark.sql.html#pyspark.sql.Window). +It defines an aggregation from one or more `pandas.Series` to a scalar value, where each `pandas.Series` +represents a column within the group or window. -The following example shows how to use this type of UDF to compute mean with groupBy and window operations: +Note that this type of UDF does not support partial aggregation and all data for a group or window +will be loaded into memory. Also, only unbounded window is supported with Grouped aggregate Pandas +UDFs currently. The following example shows how to use this type of UDF to compute mean with a group-by +and window operations:
-{% include_example grouped_agg_pandas_udf python/sql/arrow.py %} +{% include_example ser_to_scalar_pandas_udf python/sql/arrow.py %}
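The `ser_to_scalar_pandas_udf` example included above (from `examples/src/main/python/sql/arrow.py` in this patch) reduces to the sketch below; the window variant relies on the unbounded window noted above, and an active `spark` session is assumed:

{% highlight python %}
import pandas as pd

from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# Aggregates a whole pandas.Series of each group (or window) into one scalar.
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    return v.mean()

df.groupby("id").agg(mean_udf(df['v'])).show()

w = Window.partitionBy('id').rowsBetween(
    Window.unboundedPreceding, Window.unboundedFollowing)
df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()
{% endhighlight %}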
For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) -### Map Iterator +## Pandas Function APIs + +Pandas function APIs can directly apply a Python native function against the whole DataFrame by +using Pandas instances. Internally it works similarly with Pandas UDFs by Spark using Arrow to transfer +data and Pandas to work with the data, which allows vectorized operations. A Pandas function API behaves +as a regular API under PySpark `DataFrame` in general. + +From Spark 3.0, Grouped map pandas UDF is now categorized as a separate Pandas Function API, +`DataFrame.groupby().applyInPandas()`. It is still possible to use it with `PandasUDFType` +and `DataFrame.groupby().apply()` as it was; however, it is preferred to use +`DataFrame.groupby().applyInPandas()` directly. Using `PandasUDFType` will be deprecated +in the future. + +### Grouped Map + +Grouped map operations with Pandas instances are supported by `DataFrame.groupby().applyInPandas()` +which requires a Python function that takes a `pandas.DataFrame` and return another `pandas.DataFrame`. +It maps each group to each `pandas.DataFrame` in the Python function. + +This API implements the "split-apply-combine" pattern which consists of three steps: +* Split the data into groups by using `DataFrame.groupBy`. +* Apply a function on each group. The input and output of the function are both `pandas.DataFrame`. The + input data contains all the rows and columns for each group. +* Combine the results into a new PySpark `DataFrame`. -Map iterator Pandas UDFs are used to transform data with an iterator of batches. Map iterator -Pandas UDFs can be used with -[`pyspark.sql.DataFrame.mapInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). -It defines a map function that transforms an iterator of `pandas.DataFrame` to another. +To use `groupBy().applyInPandas()`, the user needs to define the following: +* A Python function that defines the computation for each group. +* A `StructType` object or a string that defines the schema of the output PySpark `DataFrame`. + +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. -It can return the output of arbitrary length in contrast to the scalar Pandas UDF. It maps an iterator of `pandas.DataFrame`s, -that represents the current `DataFrame`, using the map iterator UDF and returns the result as a `DataFrame`. +Note that all data for a group will be loaded into memory before the function is applied. This can +lead to out of memory exceptions, especially if the group sizes are skewed. The configuration for +[maxRecordsPerBatch](#setting-arrow-batch-size) is not applied on groups and it is up to the user +to ensure that the grouped data will fit into the available memory. -The following example shows how to create map iterator Pandas UDFs: +The following example shows how to use `groupby().applyInPandas()` to subtract the mean from each value +in the group.
-{% include_example map_iter_pandas_udf python/sql/arrow.py %} +{% include_example grouped_apply_in_pandas python/sql/arrow.py %}
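For reference, the `grouped_apply_in_pandas` example included above (added to `examples/src/main/python/sql/arrow.py` by this patch) is essentially:

{% highlight python %}
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

def subtract_mean(pdf):
    # pdf is a pandas.DataFrame holding all rows and columns of one group
    v = pdf.v
    return pdf.assign(v=v - v.mean())

df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
{% endhighlight %}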
-For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and -[`pyspark.sql.DataFrame.mapsInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). +For detailed usage, please see [`pyspark.sql.GroupedData.applyInPandas`](api/python/pyspark.sql.html#pyspark.sql.GroupedData.applyInPandas). +### Map + +Map operations with Pandas instances are supported by `DataFrame.mapInPandas()` which maps an iterator +of `pandas.DataFrame`s to another iterator of `pandas.DataFrame`s that represents the current +PySpark `DataFrame` and returns the result as a PySpark `DataFrame`. The functions takes and outputs +an iterator of `pandas.DataFrame`. It can return the output of arbitrary length in contrast to some +Pandas UDFs although internally it works similarly with Series to Series Pandas UDF. + +The following example shows how to use `mapInPandas()`: + +
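In condensed form, the `map_in_pandas` example that follows (added to `examples/src/main/python/sql/arrow.py` by this patch) is:

{% highlight python %}
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator):
    # Receives and yields pandas.DataFrames; the output length is arbitrary.
    for pdf in iterator:
        yield pdf[pdf.id == 1]

df.mapInPandas(filter_func, schema=df.schema).show()
{% endhighlight %}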
+
+{% include_example map_in_pandas python/sql/arrow.py %} +
+
-### Cogrouped Map +For detailed usage, please see [`pyspark.sql.DataFrame.mapsInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). -Cogrouped map Pandas UDFs allow two DataFrames to be cogrouped by a common key and then a python function applied to -each cogroup. They are used with `groupBy().cogroup().apply()` which consists of the following steps: +### Co-grouped Map +Co-grouped map operations with Pandas instances are supported by `DataFrame.groupby().cogroup().applyInPandas()` which +allows two PySpark `DataFrame`s to be cogrouped by a common key and then a Python function applied to each +cogroup. It consists of the following steps: * Shuffle the data such that the groups of each dataframe which share a key are cogrouped together. -* Apply a function to each cogroup. The input of the function is two `pandas.DataFrame` (with an optional Tuple -representing the key). The output of the function is a `pandas.DataFrame`. -* Combine the pandas.DataFrames from all groups into a new `DataFrame`. +* Apply a function to each cogroup. The input of the function is two `pandas.DataFrame` (with an optional tuple +representing the key). The output of the function is a `pandas.DataFrame`. +* Combine the `pandas.DataFrame`s from all groups into a new PySpark `DataFrame`. -To use `groupBy().cogroup().apply()`, the user needs to define the following: +To use `groupBy().cogroup().applyInPandas()`, the user needs to define the following: * A Python function that defines the computation for each cogroup. -* A `StructType` object or a string that defines the schema of the output `DataFrame`. +* A `StructType` object or a string that defines the schema of the output PySpark `DataFrame`. The column labels of the returned `pandas.DataFrame` must either match the field names in the defined output schema if specified as strings, or match the field data types by position if not @@ -201,16 +285,15 @@ Note that all data for a cogroup will be loaded into memory before the function memory exceptions, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) is not applied and it is up to the user to ensure that the cogrouped data will fit into the available memory. -The following example shows how to use `groupby().cogroup().apply()` to perform an asof join between two datasets. +The following example shows how to use `groupby().cogroup().applyInPandas()` to perform an asof join between two datasets.
-{% include_example cogrouped_map_pandas_udf python/sql/arrow.py %} +{% include_example cogrouped_apply_in_pandas python/sql/arrow.py %}
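The `cogrouped_apply_in_pandas` example included above (from `examples/src/main/python/sql/arrow.py` in this patch) is essentially:

{% highlight python %}
import pandas as pd

df1 = spark.createDataFrame(
    [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
    ("time", "id", "v1"))

df2 = spark.createDataFrame(
    [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2"))

def asof_join(l, r):
    # l and r are the pandas.DataFrames of the two cogrouped sides for one key.
    return pd.merge_asof(l, r, on="time", by="id")

df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
    asof_join, schema="time int, id int, v1 double, v2 string").show()
{% endhighlight %}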
-For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and -[`pyspark.sql.CoGroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.CoGroupedData.apply). +For detailed usage, please see [`pyspark.sql.PandasCogroupedOps.applyInPandas()`](api/python/pyspark.sql.html#pyspark.sql.PandasCogroupedOps.applyInPandas). ## Usage Notes diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index 1c983172d36ef..b7d8467172fab 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -23,12 +23,19 @@ from __future__ import print_function +import sys + from pyspark.sql import SparkSession from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version require_minimum_pandas_version() require_minimum_pyarrow_version() +if sys.version_info < (3, 6): + raise Exception( + "Running this example file requires Python 3.6+; however, " + "your Python version was:\n %s" % sys.version) + def dataframe_with_arrow_example(spark): # $example on:dataframe_with_arrow$ @@ -50,15 +57,45 @@ def dataframe_with_arrow_example(spark): print("Pandas DataFrame result statistics:\n%s\n" % str(result_pdf.describe())) -def scalar_pandas_udf_example(spark): - # $example on:scalar_pandas_udf$ +def ser_to_frame_pandas_udf_example(spark): + # $example on:ser_to_frame_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import pandas_udf + + @pandas_udf("col1 string, col2 long") + def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame: + s3['col2'] = s1 + s2.str.len() + return s3 + + # Create a Spark DataFrame that has three columns including a sturct column. + df = spark.createDataFrame( + [[1, "a string", ("a nested string",)]], + "long_col long, string_col string, struct_col struct") + + df.printSchema() + # root + # |-- long_column: long (nullable = true) + # |-- string_column: string (nullable = true) + # |-- struct_column: struct (nullable = true) + # | |-- col1: string (nullable = true) + + df.select(func("long_col", "string_col", "struct_col")).printSchema() + # |-- func(long_col, string_col, struct_col): struct (nullable = true) + # | |-- col1: string (nullable = true) + # | |-- col2: long (nullable = true) + # $example off:ser_to_frame_pandas_udf$$ + + +def ser_to_ser_pandas_udf_example(spark): + # $example on:ser_to_ser_pandas_udf$ import pandas as pd from pyspark.sql.functions import col, pandas_udf from pyspark.sql.types import LongType # Declare the function and create the UDF - def multiply_func(a, b): + def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series: return a * b multiply = pandas_udf(multiply_func, returnType=LongType()) @@ -83,26 +120,27 @@ def multiply_func(a, b): # | 4| # | 9| # +-------------------+ - # $example off:scalar_pandas_udf$ + # $example off:ser_to_ser_pandas_udf$ -def scalar_iter_pandas_udf_example(spark): - # $example on:scalar_iter_pandas_udf$ +def iter_ser_to_iter_ser_pandas_udf_example(spark): + # $example on:iter_ser_to_iter_ser_pandas_udf$ + from typing import Iterator + import pandas as pd - from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType + from pyspark.sql.functions import pandas_udf pdf = pd.DataFrame([1, 2, 3], columns=["x"]) df = spark.createDataFrame(pdf) - # When the UDF is called with a single column that is not StructType, - # the input to the underlying function is an iterator of pd.Series. 
- @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def plus_one(batch_iter): - for x in batch_iter: + # Declare the function and create the UDF + @pandas_udf("long") + def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + for x in iterator: yield x + 1 - df.select(plus_one(col("x"))).show() + df.select(plus_one("x")).show() # +-----------+ # |plus_one(x)| # +-----------+ @@ -110,15 +148,28 @@ def plus_one(batch_iter): # | 3| # | 4| # +-----------+ + # $example off:iter_ser_to_iter_ser_pandas_udf$ + + +def iter_sers_to_iter_ser_pandas_udf_example(spark): + # $example on:iter_sers_to_iter_ser_pandas_udf$ + from typing import Iterator, Tuple + + import pandas as pd - # When the UDF is called with more than one columns, - # the input to the underlying function is an iterator of pd.Series tuple. - @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def multiply_two_cols(batch_iter): - for a, b in batch_iter: + from pyspark.sql.functions import pandas_udf + + pdf = pd.DataFrame([1, 2, 3], columns=["x"]) + df = spark.createDataFrame(pdf) + + # Declare the function and create the UDF + @pandas_udf("long") + def multiply_two_cols( + iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]: + for a, b in iterator: yield a * b - df.select(multiply_two_cols(col("x"), col("x"))).show() + df.select(multiply_two_cols("x", "x")).show() # +-----------------------+ # |multiply_two_cols(x, x)| # +-----------------------+ @@ -126,92 +177,32 @@ def multiply_two_cols(batch_iter): # | 4| # | 9| # +-----------------------+ + # $example off:iter_sers_to_iter_ser_pandas_udf$ - # When the UDF is called with a single column that is StructType, - # the input to the underlying function is an iterator of pd.DataFrame. - @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def multiply_two_nested_cols(pdf_iter): - for pdf in pdf_iter: - yield pdf["a"] * pdf["b"] - - df.select( - multiply_two_nested_cols( - struct(col("x").alias("a"), col("x").alias("b")) - ).alias("y") - ).show() - # +---+ - # | y| - # +---+ - # | 1| - # | 4| - # | 9| - # +---+ - - # In the UDF, you can initialize some states before processing batches. - # Wrap your code with try/finally or use context managers to ensure - # the release of resources at the end. 
- y_bc = spark.sparkContext.broadcast(1) - - @pandas_udf("long", PandasUDFType.SCALAR_ITER) - def plus_y(batch_iter): - y = y_bc.value # initialize states - try: - for x in batch_iter: - yield x + y - finally: - pass # release resources here, if any - - df.select(plus_y(col("x"))).show() - # +---------+ - # |plus_y(x)| - # +---------+ - # | 2| - # | 3| - # | 4| - # +---------+ - # $example off:scalar_iter_pandas_udf$ - - -def grouped_map_pandas_udf_example(spark): - # $example on:grouped_map_pandas_udf$ - from pyspark.sql.functions import pandas_udf, PandasUDFType - - df = spark.createDataFrame( - [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ("id", "v")) - - @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) - def subtract_mean(pdf): - # pdf is a pandas.DataFrame - v = pdf.v - return pdf.assign(v=v - v.mean()) - - df.groupby("id").apply(subtract_mean).show() - # +---+----+ - # | id| v| - # +---+----+ - # | 1|-0.5| - # | 1| 0.5| - # | 2|-3.0| - # | 2|-1.0| - # | 2| 4.0| - # +---+----+ - # $example off:grouped_map_pandas_udf$ +def ser_to_scalar_pandas_udf_example(spark): + # $example on:ser_to_scalar_pandas_udf$ + import pandas as pd -def grouped_agg_pandas_udf_example(spark): - # $example on:grouped_agg_pandas_udf$ - from pyspark.sql.functions import pandas_udf, PandasUDFType + from pyspark.sql.functions import pandas_udf from pyspark.sql import Window df = spark.createDataFrame( [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) - @pandas_udf("double", PandasUDFType.GROUPED_AGG) - def mean_udf(v): + # Declare the function and create the UDF + @pandas_udf("double") + def mean_udf(v: pd.Series) -> float: return v.mean() + df.select(mean_udf(df['v'])).show() + # +-----------+ + # |mean_udf(v)| + # +-----------+ + # | 4.2| + # +-----------+ + df.groupby("id").agg(mean_udf(df['v'])).show() # +---+-----------+ # | id|mean_udf(v)| @@ -233,37 +224,54 @@ def mean_udf(v): # | 2| 5.0| 6.0| # | 2|10.0| 6.0| # +---+----+------+ - # $example off:grouped_agg_pandas_udf$ + # $example off:ser_to_scalar_pandas_udf$ -def map_iter_pandas_udf_example(spark): - # $example on:map_iter_pandas_udf$ - import pandas as pd +def grouped_apply_in_pandas_example(spark): + # $example on:grouped_apply_in_pandas$ + df = spark.createDataFrame( + [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ("id", "v")) - from pyspark.sql.functions import pandas_udf, PandasUDFType + def subtract_mean(pdf): + # pdf is a pandas.DataFrame + v = pdf.v + return pdf.assign(v=v - v.mean()) + + df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show() + # +---+----+ + # | id| v| + # +---+----+ + # | 1|-0.5| + # | 1| 0.5| + # | 2|-3.0| + # | 2|-1.0| + # | 2| 4.0| + # +---+----+ + # $example off:grouped_apply_in_pandas$ + +def map_in_pandas_example(spark): + # $example on:map_in_pandas$ df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) - @pandas_udf(df.schema, PandasUDFType.MAP_ITER) - def filter_func(batch_iter): - for pdf in batch_iter: + def filter_func(iterator): + for pdf in iterator: yield pdf[pdf.id == 1] - df.mapInPandas(filter_func).show() + df.mapInPandas(filter_func, schema=df.schema).show() # +---+---+ # | id|age| # +---+---+ # | 1| 21| # +---+---+ - # $example off:map_iter_pandas_udf$ + # $example off:map_in_pandas$ -def cogrouped_map_pandas_udf_example(spark): - # $example on:cogrouped_map_pandas_udf$ +def cogrouped_apply_in_pandas_example(spark): + # $example on:cogrouped_apply_in_pandas$ import pandas as pd - from pyspark.sql.functions import pandas_udf, 
PandasUDFType - df1 = spark.createDataFrame( [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ("time", "id", "v1")) @@ -272,11 +280,11 @@ def cogrouped_map_pandas_udf_example(spark): [(20000101, 1, "x"), (20000101, 2, "y")], ("time", "id", "v2")) - @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) def asof_join(l, r): return pd.merge_asof(l, r, on="time", by="id") - df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() + df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas( + asof_join, schema="time int, id int, v1 double, v2 string").show() # +--------+---+---+---+ # | time| id| v1| v2| # +--------+---+---+---+ @@ -285,7 +293,7 @@ def asof_join(l, r): # |20000101| 2|2.0| y| # |20000102| 2|4.0| y| # +--------+---+---+---+ - # $example off:cogrouped_map_pandas_udf$ + # $example off:cogrouped_apply_in_pandas$ if __name__ == "__main__": @@ -296,17 +304,21 @@ def asof_join(l, r): print("Running Pandas to/from conversion example") dataframe_with_arrow_example(spark) - print("Running pandas_udf scalar example") - scalar_pandas_udf_example(spark) - print("Running pandas_udf scalar iterator example") - scalar_iter_pandas_udf_example(spark) - print("Running pandas_udf grouped map example") - grouped_map_pandas_udf_example(spark) - print("Running pandas_udf grouped agg example") - grouped_agg_pandas_udf_example(spark) - print("Running pandas_udf map iterator example") - map_iter_pandas_udf_example(spark) - print("Running pandas_udf cogrouped map example") - cogrouped_map_pandas_udf_example(spark) + print("Running pandas_udf example: Series to Frame") + ser_to_frame_pandas_udf_example(spark) + print("Running pandas_udf example: Series to Series") + ser_to_ser_pandas_udf_example(spark) + print("Running pandas_udf example: Iterator of Series to Iterator of Seires") + iter_ser_to_iter_ser_pandas_udf_example(spark) + print("Running pandas_udf example: Iterator of Multiple Series to Iterator of Series") + iter_sers_to_iter_ser_pandas_udf_example(spark) + print("Running pandas_udf example: Series to Scalar") + ser_to_scalar_pandas_udf_example(spark) + print("Running pandas function example: Grouped Map") + grouped_apply_in_pandas_example(spark) + print("Running pandas function example: Map") + map_in_pandas_example(spark) + print("Running pandas function example: Co-grouped Map") + cogrouped_apply_in_pandas_example(spark) spark.stop() diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 30602789a33a9..31aa321bf5826 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -43,303 +43,228 @@ class PandasUDFType(object): @since(2.3) def pandas_udf(f=None, returnType=None, functionType=None): """ - Creates a vectorized user defined function (UDF). + Creates a pandas user defined function (a.k.a. vectorized user defined function). + + Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer + data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF + is defined using the `pandas_udf` as a decorator or to wrap the function, and no + additional configuration is required. A Pandas UDF behaves as a regular PySpark function + API in general. :param f: user-defined function. A python function if used as a standalone function :param returnType: the return type of the user-defined function. 
The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`. - Default: SCALAR. - - .. seealso:: :meth:`pyspark.sql.DataFrame.mapInPandas` - .. seealso:: :meth:`pyspark.sql.GroupedData.applyInPandas` - .. seealso:: :meth:`pyspark.sql.PandasCogroupedOps.applyInPandas` - - The function type of the UDF can be one of the following: - - 1. SCALAR - - A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`. - The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. - If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`. - - :class:`MapType`, nested :class:`StructType` are currently not supported as output types. - - Scalar UDFs can be used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) # doctest: +SKIP - >>> @pandas_udf(StringType()) # doctest: +SKIP - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], - ... ("id", "name", "age")) # doctest: +SKIP - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ - >>> @pandas_udf("first string, last string") # doctest: +SKIP - ... def split_expand(n): - ... return n.str.split(expand=True) - >>> df.select(split_expand("name")).show() # doctest: +SKIP - +------------------+ - |split_expand(name)| - +------------------+ - | [John, Doe]| - +------------------+ - - .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input - column, but is the length of an internal batch used for each call to the function. - Therefore, this can be used, for example, to ensure the length of each returned - `pandas.Series`, and can not be used as the column length. - - 2. SCALAR_ITER - - A scalar iterator UDF is semantically the same as the scalar Pandas UDF above except that the - wrapped Python function takes an iterator of batches as input instead of a single batch and, - instead of returning a single output batch, it yields output batches or explicitly returns an - generator or an iterator of output batches. - It is useful when the UDF execution requires initializing some state, e.g., loading a machine - learning model file to apply inference to every input batch. - - .. note:: It is not guaranteed that one invocation of a scalar iterator UDF will process all - batches from one partition, although it is currently implemented this way. - Your code shall not rely on this behavior because it might change in the future for - further optimization, e.g., one invocation processes multiple partitions. - - Scalar iterator UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and - :meth:`pyspark.sql.DataFrame.select`. 
- - >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import col, pandas_udf, struct, PandasUDFType - >>> pdf = pd.DataFrame([1, 2, 3], columns=["x"]) # doctest: +SKIP - >>> df = spark.createDataFrame(pdf) # doctest: +SKIP - - When the UDF is called with a single column that is not `StructType`, the input to the - underlying function is an iterator of `pd.Series`. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def plus_one(batch_iter): - ... for x in batch_iter: - ... yield x + 1 - ... - >>> df.select(plus_one(col("x"))).show() # doctest: +SKIP - +-----------+ - |plus_one(x)| - +-----------+ - | 2| - | 3| - | 4| - +-----------+ - - When the UDF is called with more than one columns, the input to the underlying function is an - iterator of `pd.Series` tuple. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def multiply_two_cols(batch_iter): - ... for a, b in batch_iter: - ... yield a * b - ... - >>> df.select(multiply_two_cols(col("x"), col("x"))).show() # doctest: +SKIP - +-----------------------+ - |multiply_two_cols(x, x)| - +-----------------------+ - | 1| - | 4| - | 9| - +-----------------------+ - - When the UDF is called with a single column that is `StructType`, the input to the underlying - function is an iterator of `pd.DataFrame`. - - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def multiply_two_nested_cols(pdf_iter): - ... for pdf in pdf_iter: - ... yield pdf["a"] * pdf["b"] - ... - >>> df.select( - ... multiply_two_nested_cols( - ... struct(col("x").alias("a"), col("x").alias("b")) - ... ).alias("y") - ... ).show() # doctest: +SKIP - +---+ - | y| - +---+ - | 1| - | 4| - | 9| - +---+ - - In the UDF, you can initialize some states before processing batches, wrap your code with - `try ... finally ...` or use context managers to ensure the release of resources at the end - or in case of early termination. - - >>> y_bc = spark.sparkContext.broadcast(1) # doctest: +SKIP - >>> @pandas_udf("long", PandasUDFType.SCALAR_ITER) # doctest: +SKIP - ... def plus_y(batch_iter): - ... y = y_bc.value # initialize some state - ... try: - ... for x in batch_iter: - ... yield x + y - ... finally: - ... pass # release resources here, if any - ... - >>> df.select(plus_y(col("x"))).show() # doctest: +SKIP - +---------+ - |plus_y(x)| - +---------+ - | 2| - | 3| - | 4| - +---------+ - - 3. GROUPED_MAP - - A grouped map UDF defines transformation: A `pandas.DataFrame` -> A `pandas.DataFrame` - The returnType should be a :class:`StructType` describing the schema of the returned - `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match - the field names in the defined returnType schema if specified as strings, or match the - field data types by position if not strings, e.g. integer indices. - The length of the returned `pandas.DataFrame` can be arbitrary. - - Grouped map UDFs are used with :meth:`pyspark.sql.GroupedData.apply`. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def normalize(pdf): - ... v = pdf.v - ... 
return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby("id").apply(normalize).show() # doctest: +SKIP - +---+-------------------+ - | id| v| - +---+-------------------+ - | 1|-0.7071067811865475| - | 1| 0.7071067811865475| - | 2|-0.8320502943378437| - | 2|-0.2773500981126146| - | 2| 1.1094003924504583| - +---+-------------------+ - - Alternatively, the user can define a function that takes two arguments. - In this case, the grouping key(s) will be passed as the first argument and the data will - be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy - data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in - as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. - This is useful when the user does not want to hardcode grouping key(s) in the function. - - >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) # doctest: +SKIP - >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP - ... def mean_udf(key, pdf): - ... # key is a tuple of one numpy.int64, which is the value - ... # of 'id' for the current group - ... return pd.DataFrame([key + (pdf.v.mean(),)]) - >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP - +---+---+ - | id| v| - +---+---+ - | 1|1.5| - | 2|6.0| - +---+---+ - >>> @pandas_udf( - ... "id long, `ceil(v / 2)` long, v double", - ... PandasUDFType.GROUPED_MAP) # doctest: +SKIP - >>> def sum_udf(key, pdf): - ... # key is a tuple of two numpy.int64s, which is the values - ... # of 'id' and 'ceil(df.v / 2)' for the current group - ... return pd.DataFrame([key + (pdf.v.sum(),)]) - >>> df.groupby(df.id, ceil(df.v / 2)).apply(sum_udf).show() # doctest: +SKIP - +---+-----------+----+ - | id|ceil(v / 2)| v| - +---+-----------+----+ - | 2| 5|10.0| - | 1| 1| 3.0| - | 2| 3| 5.0| - | 2| 2| 3.0| - +---+-----------+----+ - - .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is - recommended to explicitly index the columns by name to ensure the positions are correct, - or alternatively use an `OrderedDict`. - For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or - `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. - - .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - - 4. GROUPED_AGG - - A grouped aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar - The `returnType` should be a primitive data type, e.g., :class:`DoubleType`. - The returned scalar can be either a python primitive type, e.g., `int` or `float` - or a numpy data type, e.g., `numpy.int64` or `numpy.float64`. - - :class:`MapType` and :class:`StructType` are currently not supported as output types. - - Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and - :class:`pyspark.sql.Window` - - This example shows using grouped aggregated UDFs with groupby: - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def mean_udf(v): - ... 
return v.mean() - >>> df.groupby("id").agg(mean_udf(df['v'])).show() # doctest: +SKIP - +---+-----------+ - | id|mean_udf(v)| - +---+-----------+ - | 1| 1.5| - | 2| 6.0| - +---+-----------+ - - This example shows using grouped aggregated UDFs as window functions. - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> from pyspark.sql import Window - >>> df = spark.createDataFrame( - ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) - >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def mean_udf(v): - ... return v.mean() - >>> w = (Window.partitionBy('id') - ... .orderBy('v') - ... .rowsBetween(-1, 0)) - >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP - +---+----+------+ - | id| v|mean_v| - +---+----+------+ - | 1| 1.0| 1.0| - | 1| 2.0| 1.5| - | 2| 3.0| 3.0| - | 2| 5.0| 4.0| - | 2|10.0| 7.5| - +---+----+------+ - - .. note:: For performance reasons, the input series to window functions are not copied. + Default: SCALAR. + + .. note:: This parameter exists for compatibility. Using Python type hints is encouraged. + + In order to use this API, customarily the below are imported: + + >>> import pandas as pd + >>> from pyspark.sql.functions import pandas_udf + + From Spark 3.0 with Python 3.6+, `Python type hints `_ + detect the function types as below: + + >>> @pandas_udf(IntegerType()) + ... def slen(s: pd.Series) -> pd.Series: + ... return s.str.len() + + Prior to Spark 3.0, the pandas UDF used `functionType` to decide the execution type as below: + + >>> from pyspark.sql.functions import PandasUDFType + >>> from pyspark.sql.types import IntegerType + >>> @pandas_udf(IntegerType(), PandasUDFType.SCALAR) + ... def slen(s): + ... return s.str.len() + + It is preferred to specify type hints for the pandas UDF instead of specifying pandas UDF + type via `functionType` which will be deprecated in the future releases. + + Note that the type hint should use `pandas.Series` in all cases but there is one variant + that `pandas.DataFrame` should be used for its input or output type hint instead when the input + or output column is of :class:`pyspark.sql.types.StructType`. The following example shows + a Pandas UDF which takes long column, string column and struct column, and outputs a struct + column. It requires the function to specify the type hints of `pandas.Series` and + `pandas.DataFrame` as below: + + >>> @pandas_udf("col1 string, col2 long") + >>> def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame: + ... s3['col2'] = s1 + s2.str.len() + ... return s3 + ... + >>> # Create a Spark DataFrame that has three columns including a sturct column. + ... df = spark.createDataFrame( + ... [[1, "a string", ("a nested string",)]], + ... "long_col long, string_col string, struct_col struct") + >>> df.printSchema() + root + |-- long_column: long (nullable = true) + |-- string_column: string (nullable = true) + |-- struct_column: struct (nullable = true) + | |-- col1: string (nullable = true) + >>> df.select(func("long_col", "string_col", "struct_col")).printSchema() + |-- func(long_col, string_col, struct_col): struct (nullable = true) + | |-- col1: string (nullable = true) + | |-- col2: long (nullable = true) + + In the following sections, it describes the cominations of the supported type hints. For + simplicity, `pandas.DataFrame` variant is omitted. + + * Series to Series + `pandas.Series`, ... 
-> `pandas.Series` + + The function takes one or more `pandas.Series` and outputs one `pandas.Series`. + The output of the function should always be of the same length as the input. + + >>> @pandas_udf("string") + ... def to_upper(s: pd.Series) -> pd.Series: + ... return s.str.upper() + ... + >>> df = spark.createDataFrame([("John Doe",)], ("name",)) + >>> df.select(to_upper("name")).show() + +--------------+ + |to_upper(name)| + +--------------+ + | JOHN DOE| + +--------------+ + + >>> @pandas_udf("first string, last string") + ... def split_expand(s: pd.Series) -> pd.DataFrame: + ... return s.str.split(expand=True) + ... + >>> df = spark.createDataFrame([("John Doe",)], ("name",)) + >>> df.select(split_expand("name")).show() + +------------------+ + |split_expand(name)| + +------------------+ + | [John, Doe]| + +------------------+ + + .. note:: The length of the input is not that of the whole input column, but is the + length of an internal batch used for each call to the function. + + * Iterator of Series to Iterator of Series + `Iterator[pandas.Series]` -> `Iterator[pandas.Series]` + + The function takes an iterator of `pandas.Series` and outputs an iterator of + `pandas.Series`. In this case, the created pandas UDF instance requires one input + column when this is called as a PySpark column. The output of each series from + the function should always be of the same length as the input. + + It is useful when the UDF execution + requires initializing some states although internally it works identically as + Series to Series case. The pseudocode below illustrates the example. + + .. highlight:: python + .. code-block:: python + + @pandas_udf("long") + def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + # Do some expensive initialization with a state + state = very_expensive_initialization() + for x in iterator: + # Use that state for whole iterator. + yield calculate_with_state(x, state) + + df.select(calculate("value")).show() + + >>> from typing import Iterator + >>> @pandas_udf("long") + ... def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: + ... for s in iterator: + ... yield s + 1 + ... + >>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"])) + >>> df.select(plus_one(df.v)).show() + +-----------+ + |plus_one(v)| + +-----------+ + | 2| + | 3| + | 4| + +-----------+ + + .. note:: The length of each series is the length of a batch internally used. + + * Iterator of Multiple Series to Iterator of Series + `Iterator[Tuple[pandas.Series, ...]]` -> `Iterator[pandas.Series]` + + The function takes an iterator of a tuple of multiple `pandas.Series` and outputs an + iterator of `pandas.Series`. In this case, the created pandas UDF instance requires + input columns as many as the series when this is called as a PySpark column. + It works identically as Iterator of Series to Iterator of Series case except + the parameter difference. The output of each series from the function should always + be of the same length as the input. + + >>> from typing import Iterator, Tuple + >>> from pyspark.sql.functions import struct, col + >>> @pandas_udf("long") + ... def multiply(iterator: Iterator[Tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]: + ... for s1, df in iterator: + ... yield s1 * df.v + ... + >>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"])) + >>> df.withColumn('output', multiply(col("v"), struct(col("v")))).show() + +---+------+ + | v|output| + +---+------+ + | 1| 1| + | 2| 4| + | 3| 9| + +---+------+ + + .. 
note:: The length of each series is the length of a batch internally used. + + * Series to Scalar + `pandas.Series`, ... -> `Any` + + The function takes `pandas.Series` and returns a scalar value. The `returnType` + should be a primitive data type, and the returned scalar can be either a python primitive + type, e.g., int or float or a numpy data type, e.g., numpy.int64 or numpy.float64. + `Any` should ideally be a specific scalar type accordingly. + + >>> @pandas_udf("double") + ... def mean_udf(v: pd.Series) -> float: + ... return v.mean() + ... + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) + >>> df.groupby("id").agg(mean_udf(df['v'])).show() + +---+-----------+ + | id|mean_udf(v)| + +---+-----------+ + | 1| 1.5| + | 2| 6.0| + +---+-----------+ + + This UDF can also be used as window functions as below: + + >>> from pyspark.sql import Window + >>> @pandas_udf("double") + ... def mean_udf(v: pd.Series) -> float: + ... return v.mean() + ... + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")) + >>> w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0) + >>> df.withColumn('mean_v', mean_udf("v").over(w)).show() + +---+----+------+ + | id| v|mean_v| + +---+----+------+ + | 1| 1.0| 1.0| + | 1| 2.0| 1.5| + | 2| 3.0| 3.0| + | 2| 5.0| 4.0| + | 2|10.0| 7.5| + +---+----+------+ + + .. note:: For performance reasons, the input series to window functions are not copied. Therefore, mutating the input series is not allowed and will cause incorrect results. For the same reason, users should also not rely on the index of the input series. - .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` + .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window` .. note:: The user-defined functions do not support conditional expressions or short circuiting in boolean expressions and it ends up with being executed all internally. If the functions @@ -348,10 +273,21 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. note:: The user-defined functions do not take keyword arguments on the calling side. .. note:: The data type of returned `pandas.Series` from the user-defined functions should be - matched with defined returnType (see :meth:`types.to_arrow_type` and + matched with defined `returnType` (see :meth:`types.to_arrow_type` and :meth:`types.from_arrow_type`). When there is mismatch between them, Spark might do conversion on returned data. The conversion is not guaranteed to be correct and results should be checked for accuracy by users. + + .. note:: Currently, + :class:`pyspark.sql.types.MapType`, + :class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and + nested :class:`pyspark.sql.types.StructType` + are currently not supported as output types. + + .. seealso:: :meth:`pyspark.sql.DataFrame.mapInPandas` + .. seealso:: :meth:`pyspark.sql.GroupedData.applyInPandas` + .. seealso:: :meth:`pyspark.sql.PandasCogroupedOps.applyInPandas` + .. 
seealso:: :meth:`pyspark.sql.UDFRegistration.register` """ # The following table shows most of Pandas data and SQL type conversions in Pandas UDFs that @@ -480,25 +416,3 @@ def _create_pandas_udf(f, returnType, evalType): "or three arguments (key, left, right).") return _create_udf(f, returnType, evalType) - - -def _test(): - import doctest - from pyspark.sql import SparkSession - import pyspark.sql.pandas.functions - globs = pyspark.sql.pandas.functions.__dict__.copy() - spark = SparkSession.builder\ - .master("local[4]")\ - .appName("sql.pandas.functions tests")\ - .getOrCreate() - globs['spark'] = spark - (failure_count, test_count) = doctest.testmod( - pyspark.sql.pandas.functions, globs=globs, - optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF) - spark.stop() - if failure_count: - sys.exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 3152271ba9df8..b93f0516cadb1 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -88,29 +88,27 @@ def applyInPandas(self, func, schema): to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. - The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the - returnType of the pandas udf. - - .. note:: This function requires a full shuffle. All the data of a group will be loaded - into memory, so the user should be aware of the potential OOM risk if data is skewed - and certain groups are too large to fit in memory. + The `schema` should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. :param func: a Python native function that takes a `pandas.DataFrame`, and outputs a `pandas.DataFrame`. :param schema: the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - .. note:: Experimental - - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf, ceil >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], - ... ("id", "v")) + ... ("id", "v")) # doctest: +SKIP >>> def normalize(pdf): ... v = pdf.v ... return pdf.assign(v=(v - v.mean()) / v.std()) - >>> df.groupby("id").applyInPandas(normalize, schema="id long, v double").show() - ... # doctest: +SKIP + >>> df.groupby("id").applyInPandas( + ... normalize, schema="id long, v double").show() # doctest: +SKIP +---+-------------------+ | id| v| +---+-------------------+ @@ -121,8 +119,56 @@ def applyInPandas(self, func, schema): | 2| 1.1094003924504583| +---+-------------------+ - .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + Alternatively, the user can pass a function that takes two arguments. + In this case, the grouping key(s) will be passed as the first argument and the data will + be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. 
The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key(s) in the function. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> def mean_func(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').applyInPandas( + ... mean_func, schema="id long, v double").show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + >>> def sum_func(key, pdf): + ... # key is a tuple of two numpy.int64s, which is the values + ... # of 'id' and 'ceil(df.v / 2)' for the current group + ... return pd.DataFrame([key + (pdf.v.sum(),)]) + >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas( + ... sum_func, schema="id long, `ceil(v / 2)` long, v double").show() # doctest: +SKIP + +---+-----------+----+ + | id|ceil(v / 2)| v| + +---+-----------+----+ + | 2| 5|10.0| + | 1| 1| 3.0| + | 2| 3| 5.0| + | 2| 2| 3.0| + +---+-----------+----+ + + .. note:: This function requires a full shuffle. All the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + .. note:: Experimental + + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ from pyspark.sql import GroupedData from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -176,14 +222,11 @@ def applyInPandas(self, func, schema): `pandas.DataFrame` to the user-function and the returned `pandas.DataFrame` are combined as a :class:`DataFrame`. - The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the - returnType of the pandas udf. - - .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded - into memory, so the user should be aware of the potential OOM risk if data is skewed - and certain groups are too large to fit in memory. - - .. note:: Experimental + The `schema` should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either match + the field names in the defined schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pandas.DataFrame` can be arbitrary. :param func: a Python native function that takes two `pandas.DataFrame`\\s, and outputs a `pandas.DataFrame`, or that takes one tuple (grouping keys) and two @@ -191,7 +234,7 @@ def applyInPandas(self, func, schema): :param schema: the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql.functions import pandas_udf >>> df1 = spark.createDataFrame( ... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ... 
("time", "id", "v1")) @@ -232,6 +275,18 @@ def applyInPandas(self, func, schema): |20000102| 1|3.0| x| +--------+---+---+---+ + .. note:: This function requires a full shuffle. All the data of a cogroup will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is + recommended to explicitly index the columns by name to ensure the positions are correct, + or alternatively use an `OrderedDict`. + For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or + `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`. + + .. note:: Experimental + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` """ diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index 75cacd797f9dd..9835e88c6ac21 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -45,10 +45,10 @@ def mapInPandas(self, func, schema): :param schema: the return type of the `func` in PySpark. The value can be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> from pyspark.sql.functions import pandas_udf >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")) - >>> def filter_func(batch_iter): - ... for pdf in batch_iter: + >>> def filter_func(iterator): + ... for pdf in iterator: ... yield pdf[pdf.id == 1] >>> df.mapInPandas(filter_func, df.schema).show() # doctest: +SKIP +---+---+ diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 433c5fc845c59..10546ecacc57f 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -297,17 +297,18 @@ def register(self, name, f, returnType=None): >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP [Row(random_udf()=82)] - >>> from pyspark.sql.functions import pandas_udf, PandasUDFType - >>> @pandas_udf("integer", PandasUDFType.SCALAR) # doctest: +SKIP - ... def add_one(x): - ... return x + 1 + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.functions import pandas_udf + >>> @pandas_udf("integer") # doctest: +SKIP + ... def add_one(s: pd.Series) -> pd.Series: + ... return s + 1 ... >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)] - >>> @pandas_udf("integer", PandasUDFType.GROUPED_AGG) # doctest: +SKIP - ... def sum_udf(v): + >>> @pandas_udf("integer") # doctest: +SKIP + ... def sum_udf(v: pd.Series) -> int: ... return v.sum() ... >>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP @@ -414,6 +415,9 @@ def _test(): .appName("sql.udf tests")\ .getOrCreate() globs['spark'] = spark + # Hack to skip the unit tests in register. These are currently being tested in proper tests. + # We should reenable this test once we completely drop Python 2. 
+ del pyspark.sql.udf.UDFRegistration.register (failure_count, test_count) = doctest.testmod( pyspark.sql.udf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) From b4769998efee0f5998104b689b710c11ee0dbd14 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Wed, 12 Feb 2020 15:19:16 +0900 Subject: [PATCH 021/185] [SPARK-30795][SQL] Spark SQL codegen's code() interpolator should treat escapes like Scala's StringContext.s() ### What changes were proposed in this pull request? This PR proposes to make the `code` string interpolator treat escapes the same way as Scala's builtin `StringContext.s()` string interpolator. This will remove the need for an ugly workaround in `Like` expression's codegen. ### Why are the changes needed? The `code()` string interpolator in Spark SQL's code generator should treat escapes like Scala's builtin `StringContext.s()` interpolator, i.e. it should treat escapes in the code parts, and should not treat escapes in the input arguments. For example, ```scala val arg = "This is an argument." val str = s"This is string part 1. $arg This is string part 2." val code = code"This is string part 1. $arg This is string part 2." assert(code.toString == str) ``` We should expect the `code()` interpolator to produce the same result as the `StringContext.s()` interpolator, where only escapes in the string parts should be treated, while the args should be kept verbatim. But in the current implementation, due to the eager folding of code parts and literal input args, the escape treatment is incorrectly done on both code parts and literal args. That causes a problem when an arg contains escape sequences and wants to preserve that in the final produced code string. For example, in `Like` expression's codegen, there's an ugly workaround for this bug: ```scala // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') { s"""\\\\\\$escapeChar""" } else { escapeChar } ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added a new unit test case in `CodeBlockSuite`. Closes #27544 from rednaxelafx/fix-code-string-interpolator. Authored-by: Kris Mok Signed-off-by: HyukjinKwon --- .../sql/catalyst/expressions/codegen/javaCode.scala | 13 +++++++++---- .../catalyst/expressions/regexpExpressions.scala | 13 ++++--------- .../expressions/codegen/CodeBlockSuite.scala | 12 ++++++++++++ 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index d9393b9df6bbd..dff258902a0b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -223,6 +223,11 @@ object Block { implicit def blocksToBlock(blocks: Seq[Block]): Block = blocks.reduceLeft(_ + _) implicit class BlockHelper(val sc: StringContext) extends AnyVal { + /** + * A string interpolator that retains references to the `JavaCode` inputs, and behaves like + * the Scala builtin StringContext.s() interpolator otherwise, i.e. 
it will treat escapes in + * the code parts, and will not treat escapes in the input arguments. + */ def code(args: Any*): Block = { sc.checkLengths(args) if (sc.parts.length == 0) { @@ -250,7 +255,7 @@ object Block { val inputs = args.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) while (strings.hasNext) { val input = inputs.next input match { @@ -262,7 +267,7 @@ object Block { case _ => buf.append(input) } - buf.append(strings.next) + buf.append(StringContext.treatEscapes(strings.next)) } codeParts += buf.toString @@ -286,10 +291,10 @@ case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends val strings = codeParts.iterator val inputs = blockInputs.iterator val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) - buf.append(StringContext.treatEscapes(strings.next)) + buf.append(strings.next) while (strings.hasNext) { buf.append(inputs.next) - buf.append(StringContext.treatEscapes(strings.next)) + buf.append(strings.next) } buf.toString } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index e5ee0edfcf79b..6a4d813d345b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -159,19 +159,14 @@ case class Like(left: Expression, right: Expression, escapeChar: Char) } else { val pattern = ctx.freshName("pattern") val rightStr = ctx.freshName("rightStr") - // We need double escape to avoid org.codehaus.commons.compiler.CompileException. - // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. - // '\"' will cause exception 'Line break in literal not allowed'. - val newEscapeChar = if (escapeChar == '\"' || escapeChar == '\\') { - s"""\\\\\\$escapeChar""" - } else { - escapeChar - } + // We need to escape the escapeChar to make sure the generated code is valid. + // Otherwise we'll hit org.codehaus.commons.compiler.CompileException. 
+ val escapedEscapeChar = StringEscapeUtils.escapeJava(escapeChar.toString) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" String $rightStr = $eval2.toString(); $patternClass $pattern = $patternClass.compile( - $escapeFunc($rightStr, '$newEscapeChar')); + $escapeFunc($rightStr, '$escapedEscapeChar')); ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala index 55569b6f2933e..67e3bc69543e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala @@ -37,6 +37,18 @@ class CodeBlockSuite extends SparkFunSuite { assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value)) } + test("Code parts should be treated for escapes, but string inputs shouldn't be") { + val strlit = raw"\\" + val code = code"""String s = "foo\\bar" + "$strlit";""" + + val builtin = s"""String s = "foo\\bar" + "$strlit";""" + + val expected = raw"""String s = "foo\bar" + "\\";""" + + assert(builtin == expected) + assert(code.asInstanceOf[CodeBlock].toString == expected) + } + test("Block.stripMargin") { val isNull = JavaCode.isNullVariable("expr1_isNull") val value = JavaCode.variable("expr1", IntegerType) From f5026b1ba7c05548d5f271d6d3edf7dfd4c3f9ce Mon Sep 17 00:00:00 2001 From: beliefer Date: Wed, 12 Feb 2020 14:49:22 +0800 Subject: [PATCH 022/185] [SPARK-30763][SQL] Fix java.lang.IndexOutOfBoundsException No group 1 for regexp_extract ### What changes were proposed in this pull request? The current implement of `regexp_extract` will throws a unprocessed exception show below: `SELECT regexp_extract('1a 2b 14m', 'd+')` ``` java.lang.IndexOutOfBoundsException: No group 1 [info] at java.util.regex.Matcher.group(Matcher.java:538) [info] at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) [info] at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) [info] at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:729) ``` I think should treat this exception well. ### Why are the changes needed? Fix a bug `java.lang.IndexOutOfBoundsException No group 1 ` ### Does this PR introduce any user-facing change? Yes ### How was this patch tested? New UT Closes #27508 from beliefer/fix-regexp_extract-bug. 
Authored-by: beliefer Signed-off-by: Wenchen Fan --- .../expressions/regexpExpressions.scala | 15 +++- .../expressions/RegexpExpressionsSuite.scala | 12 ++++ .../sql-tests/inputs/regexp-functions.sql | 9 +++ .../results/regexp-functions.sql.out | 69 +++++++++++++++++++ 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 6a4d813d345b3..3f60ca388a807 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -410,6 +410,15 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } } +object RegExpExtract { + def checkGroupIndex(groupCount: Int, groupIndex: Int): Unit = { + if (groupCount < groupIndex) { + throw new IllegalArgumentException( + s"Regex group count is $groupCount, but the specified group index is $groupIndex") + } + } +} + /** * Extract a specific(idx) group identified by a Java regex. * @@ -441,7 +450,9 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio val m = pattern.matcher(s.toString) if (m.find) { val mr: MatchResult = m.toMatchResult - val group = mr.group(r.asInstanceOf[Int]) + val index = r.asInstanceOf[Int] + RegExpExtract.checkGroupIndex(mr.groupCount, index) + val group = mr.group(index) if (group == null) { // Pattern matched, but not optional group UTF8String.EMPTY_UTF8 } else { @@ -459,6 +470,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val classNamePattern = classOf[Pattern].getCanonicalName + val classNameRegExpExtract = classOf[RegExpExtract].getCanonicalName val matcher = ctx.freshName("matcher") val matchResult = ctx.freshName("matchResult") @@ -482,6 +494,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio $termPattern.matcher($subject.toString()); if ($matcher.find()) { java.util.regex.MatchResult $matchResult = $matcher.toMatchResult(); + $classNameRegExpExtract.checkGroupIndex($matchResult.groupCount(), $idx); if ($matchResult.group($idx) == null) { ${ev.value} = UTF8String.EMPTY_UTF8; } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 2c8794f083dbb..86da62bc74940 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -293,6 +293,18 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val nonNullExpr = RegExpExtract(Literal("100-200"), Literal("(\\d+)-(\\d+)"), Literal(1)) checkEvaluation(nonNullExpr, "100", row1) + + // invalid group index + val row8 = create_row("100-200", "(\\d+)-(\\d+)", 3) + val row9 = create_row("100-200", "(\\d+).*", 2) + val row10 = create_row("100-200", "\\d+", 1) + + checkExceptionInExpression[IllegalArgumentException]( + expr, 
row8, "Regex group count is 2, but the specified group index is 3") + checkExceptionInExpression[IllegalArgumentException]( + expr, row9, "Regex group count is 1, but the specified group index is 2") + checkExceptionInExpression[IllegalArgumentException]( + expr, row10, "Regex group count is 0, but the specified group index is 1") } test("SPLIT") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql new file mode 100644 index 0000000000000..c0827a3cba39b --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/regexp-functions.sql @@ -0,0 +1,9 @@ +-- regexp_extract +SELECT regexp_extract('1a 2b 14m', '\\d+'); +SELECT regexp_extract('1a 2b 14m', '\\d+', 0); +SELECT regexp_extract('1a 2b 14m', '\\d+', 1); +SELECT regexp_extract('1a 2b 14m', '\\d+', 2); +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)'); +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 0); +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 1); +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 2); diff --git a/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out new file mode 100644 index 0000000000000..c92c1ddca774f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/regexp-functions.sql.out @@ -0,0 +1,69 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 8 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '\\d+') +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 0, but the specified group index is 1 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '\\d+', 0) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '\\d+', 1) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 0, but the specified group index is 1 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '\\d+', 2) +-- !query schema +struct<> +-- !query output +java.lang.IllegalArgumentException +Regex group count is 0, but the specified group index is 2 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)') +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 0) +-- !query schema +struct +-- !query output +1a + + +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 1) +-- !query schema +struct +-- !query output +1 + + +-- !query +SELECT regexp_extract('1a 2b 14m', '(\\d+)([a-z]+)', 2) +-- !query schema +struct +-- !query output +a From 8b1839728acaa5e61f542a7332505289726d3162 Mon Sep 17 00:00:00 2001 From: turbofei Date: Wed, 12 Feb 2020 20:21:52 +0900 Subject: [PATCH 023/185] [SPARK-29542][FOLLOW-UP] Keep the description of spark.sql.files.* in tuning guide be consistent with that in SQLConf ### What changes were proposed in this pull request? This pr is a follow up of https://github.com/apache/spark/pull/26200. In this PR, I modify the description of spark.sql.files.* in sql-performance-tuning.md to keep consistent with that in SQLConf. ### Why are the changes needed? To keep consistent with the description in SQLConf. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existed UT. Closes #27545 from turboFei/SPARK-29542-follow-up. 
Authored-by: turbofei Signed-off-by: HyukjinKwon --- docs/sql-performance-tuning.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/sql-performance-tuning.md b/docs/sql-performance-tuning.md index e289854c7acc7..5a86c0cc31e12 100644 --- a/docs/sql-performance-tuning.md +++ b/docs/sql-performance-tuning.md @@ -67,6 +67,7 @@ that these options will be deprecated in future release as more optimizations ar 134217728 (128 MB) The maximum number of bytes to pack into a single partition when reading files. + This configuration is effective only when using file-based sources such as Parquet, JSON and ORC. @@ -76,7 +77,8 @@ that these options will be deprecated in future release as more optimizations ar The estimated cost to open a file, measured by the number of bytes could be scanned in the same time. This is used when putting multiple files into a partition. It is better to over-estimated, then the partitions with small files will be faster than partitions with bigger files (which is - scheduled first). + scheduled first). This configuration is effective only when using file-based sources such as Parquet, + JSON and ORC. From c1986204e59f1e8cc4b611d5a578cb248cb74c28 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 12 Feb 2020 20:12:38 +0800 Subject: [PATCH 024/185] [SPARK-30788][SQL] Support `SimpleDateFormat` and `FastDateFormat` as legacy date/timestamp formatters ### What changes were proposed in this pull request? In the PR, I propose to add legacy date/timestamp formatters based on `SimpleDateFormat` and `FastDateFormat`: - `LegacyFastTimestampFormatter` - uses `FastDateFormat` and supports parsing/formatting in microsecond precision. The code was borrowed from Spark 2.4, see https://github.com/apache/spark/pull/26507 & https://github.com/apache/spark/pull/26582 - `LegacySimpleTimestampFormatter` uses `SimpleDateFormat`, and support the `lenient` mode. When the `lenient` parameter is set to `false`, the parser become much stronger in checking its input. ### Why are the changes needed? Spark 2.4.x uses the following parsers for parsing/formatting date/timestamp strings: - `DateTimeFormat` in CSV/JSON datasource - `SimpleDateFormat` - is used in JDBC datasource, in partitions parsing. - `SimpleDateFormat` in strong mode (`lenient = false`), see https://github.com/apache/spark/blob/branch-2.4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala#L124. It is used by the `date_format`, `from_unixtime`, `unix_timestamp` and `to_unix_timestamp` functions. The PR aims to make Spark 3.0 compatible with Spark 2.4.x in all those cases when `spark.sql.legacy.timeParser.enabled` is set to `true`. ### Does this PR introduce any user-facing change? This shouldn't change behavior with default settings. If `spark.sql.legacy.timeParser.enabled` is set to `true`, users should observe behavior of Spark 2.4. ### How was this patch tested? - Modified tests in `DateExpressionsSuite` to check the legacy parser - `SimpleDateFormat`. - Added `CSVLegacyTimeParserSuite` and `JsonLegacyTimeParserSuite` to run `CSVSuite` and `JsonSuite` with the legacy parser - `FastDateFormat`. Closes #27524 from MaxGekk/timestamp-formatter-legacy-fallback. 
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/CSVInferSchema.scala | 4 +- .../spark/sql/catalyst/csv/CSVOptions.scala | 4 +- .../sql/catalyst/csv/UnivocityGenerator.scala | 7 +- .../sql/catalyst/csv/UnivocityParser.scala | 7 +- .../expressions/datetimeExpressions.scala | 52 ++- .../spark/sql/catalyst/json/JSONOptions.scala | 4 +- .../sql/catalyst/json/JacksonGenerator.scala | 7 +- .../sql/catalyst/json/JacksonParser.scala | 7 +- .../sql/catalyst/json/JsonInferSchema.scala | 4 +- .../sql/catalyst/util/DateFormatter.scala | 66 ++- .../catalyst/util/TimestampFormatter.scala | 132 +++++- .../org/apache/spark/sql/types/Decimal.scala | 2 +- .../expressions/DateExpressionsSuite.scala | 390 ++++++++++-------- .../org/apache/spark/sql/functions.scala | 7 +- .../resources/test-data/bad_after_good.csv | 2 +- .../resources/test-data/value-malformed.csv | 2 +- .../apache/spark/sql/DateFunctionsSuite.scala | 346 +++++++++------- .../execution/datasources/csv/CSVSuite.scala | 23 +- .../datasources/json/JsonSuite.scala | 7 + 19 files changed, 654 insertions(+), 419 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 03cc3cbdf790a..c6a03183ab45e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -24,6 +24,7 @@ import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.types._ @@ -32,7 +33,8 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private val timestampParser = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val decimalParser = if (options.locale == Locale.US) { // Special handling the default locale for backward compatibility diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 5e40d74e54f11..8892037e03a7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -146,10 +146,10 @@ class CSVOptions( // A language tag in IETF BCP 47 format val locale: Locale = parameters.get("locale").map(Locale.forLanguageTag).getOrElse(Locale.US) - val dateFormat: String = parameters.getOrElse("dateFormat", "uuuu-MM-dd") + val dateFormat: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) val timestampFormat: String = - parameters.getOrElse("timestampFormat", "uuuu-MM-dd'T'HH:mm:ss.SSSXXX") + parameters.getOrElse("timestampFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 05cb91d10868e..00e3d49787db1 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -23,6 +23,7 @@ import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ class UnivocityGenerator( @@ -44,11 +45,13 @@ class UnivocityGenerator( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private def makeConverter(dataType: DataType): ValueConverter = dataType match { case DateType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 5510953804025..cd69c21a01976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -86,11 +87,13 @@ class UnivocityParser( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val csvFilters = new CSVFilters(filters, requiredSchema) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index aa2bd5a1273e0..1f4c8c041f8bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -30,9 +30,10 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, LegacyDateFormats, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -622,13 +623,15 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti @transient private lazy val formatter: 
Option[TimestampFormatter] = { if (right.foldable) { - Option(right.eval()).map(format => TimestampFormatter(format.toString, zoneId)) + Option(right.eval()).map { format => + TimestampFormatter(format.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) + } } else None } override protected def nullSafeEval(timestamp: Any, format: Any): Any = { val tf = if (formatter.isEmpty) { - TimestampFormatter(format.toString, zoneId) + TimestampFormatter(format.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) } else { formatter.get } @@ -643,10 +646,14 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti }) }.getOrElse { val tf = TimestampFormatter.getClass.getName.stripSuffix("$") + val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) defineCodeGen(ctx, ev, (timestamp, format) => { - s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $zid) - .format($timestamp))""" + s"""|UTF8String.fromString($tf$$.MODULE$$.apply( + | $format.toString(), + | $zid, + | $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT()) + |.format($timestamp))""".stripMargin }) } } @@ -688,7 +695,7 @@ case class ToUnixTimestamp( copy(timeZoneId = Option(timeZoneId)) def this(time: Expression) = { - this(time, Literal("uuuu-MM-dd HH:mm:ss")) + this(time, Literal(TimestampFormatter.defaultPattern)) } override def prettyName: String = "to_unix_timestamp" @@ -732,7 +739,7 @@ case class UnixTimestamp(timeExp: Expression, format: Expression, timeZoneId: Op copy(timeZoneId = Option(timeZoneId)) def this(time: Expression) = { - this(time, Literal("uuuu-MM-dd HH:mm:ss")) + this(time, Literal(TimestampFormatter.defaultPattern)) } def this() = { @@ -758,7 +765,7 @@ abstract class ToTimestamp private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: TimestampFormatter = try { - TimestampFormatter(constFormat.toString, zoneId) + TimestampFormatter(constFormat.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) } catch { case NonFatal(_) => null } @@ -791,8 +798,8 @@ abstract class ToTimestamp } else { val formatString = f.asInstanceOf[UTF8String].toString try { - TimestampFormatter(formatString, zoneId).parse( - t.asInstanceOf[UTF8String].toString) / downScaleFactor + TimestampFormatter(formatString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) + .parse(t.asInstanceOf[UTF8String].toString) / downScaleFactor } catch { case NonFatal(_) => null } @@ -831,13 +838,16 @@ abstract class ToTimestamp } case StringType => val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val locale = ctx.addReferenceObj("locale", Locale.US) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") + val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $zid, $locale) - .parse($string.toString()) / $downScaleFactor; + ${ev.value} = $tf$$.MODULE$$.apply( + $format.toString(), + $zid, + $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT()) + .parse($string.toString()) / $downScaleFactor; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; } catch (java.text.ParseException e) { @@ -908,7 +918,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def prettyName: String = "from_unixtime" def this(unix: Expression) = { - this(unix, Literal("uuuu-MM-dd HH:mm:ss")) + this(unix, Literal(TimestampFormatter.defaultPattern)) } 
override def dataType: DataType = StringType @@ -922,7 +932,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] private lazy val formatter: TimestampFormatter = try { - TimestampFormatter(constFormat.toString, zoneId) + TimestampFormatter(constFormat.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) } catch { case NonFatal(_) => null } @@ -948,8 +958,9 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ null } else { try { - UTF8String.fromString(TimestampFormatter(f.toString, zoneId) - .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) + UTF8String.fromString( + TimestampFormatter(f.toString, zoneId, legacyFormat = SIMPLE_DATE_FORMAT) + .format(time.asInstanceOf[Long] * MICROS_PER_SECOND)) } catch { case NonFatal(_) => null } @@ -980,13 +991,14 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } } else { val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val locale = ctx.addReferenceObj("locale", Locale.US) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") + val ldf = LegacyDateFormats.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.value} = UTF8String.fromString($tf$$.MODULE$$.apply($f.toString(), $zid, $locale). - format($seconds * 1000000L)); + ${ev.value} = UTF8String.fromString( + $tf$$.MODULE$$.apply($f.toString(), $zid, $ldf$$.MODULE$$.SIMPLE_DATE_FORMAT()) + .format($seconds * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; }""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala index cdf4b4689e821..45c4edff47070 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala @@ -88,10 +88,10 @@ private[sql] class JSONOptions( val zoneId: ZoneId = DateTimeUtils.getZoneId( parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId)) - val dateFormat: String = parameters.getOrElse("dateFormat", "uuuu-MM-dd") + val dateFormat: String = parameters.getOrElse("dateFormat", DateFormatter.defaultPattern) val timestampFormat: String = - parameters.getOrElse("timestampFormat", "uuuu-MM-dd'T'HH:mm:ss.SSSXXX") + parameters.getOrElse("timestampFormat", s"${DateFormatter.defaultPattern}'T'HH:mm:ss.SSSXXX") val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala index 9c63593ea1752..141360ff02117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala @@ -24,6 +24,7 @@ import com.fasterxml.jackson.core._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.types._ /** @@ -80,11 +81,13 @@ private[sql] class JacksonGenerator( private val timestampFormatter = TimestampFormatter( options.timestampFormat, 
options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private def makeWriter(dataType: DataType): ValueWriter = dataType match { case NullType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 76efa574a99ff..1e408cdb126b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -58,11 +59,13 @@ class JacksonParser( private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) private val dateFormatter = DateFormatter( options.dateFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) /** * Create a converter which converts the JSON documents held by the `JsonParser` diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index f030955ee6e7f..82dd6d0da2632 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -40,7 +41,8 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable { private val timestampFormatter = TimestampFormatter( options.timestampFormat, options.zoneId, - options.locale) + options.locale, + legacyFormat = FAST_DATE_FORMAT) /** * Infer the type of a collection of json records in three stages: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala index 28189b65dee9a..2cf82d1cfa177 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.util +import java.text.SimpleDateFormat import java.time.{LocalDate, ZoneId} -import java.util.Locale +import java.util.{Date, Locale} import org.apache.commons.lang3.time.FastDateFormat @@ -51,41 +52,76 @@ class Iso8601DateFormatter( } } -class LegacyDateFormatter(pattern: String, locale: Locale) extends 
DateFormatter { - @transient - private lazy val format = FastDateFormat.getInstance(pattern, locale) +trait LegacyDateFormatter extends DateFormatter { + def parseToDate(s: String): Date + def formatDate(d: Date): String override def parse(s: String): Int = { - val milliseconds = format.parse(s).getTime + val milliseconds = parseToDate(s).getTime DateTimeUtils.millisToDays(milliseconds) } override def format(days: Int): String = { val date = DateTimeUtils.toJavaDate(days) - format.format(date) + formatDate(date) } } +class LegacyFastDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { + @transient + private lazy val fdf = FastDateFormat.getInstance(pattern, locale) + override def parseToDate(s: String): Date = fdf.parse(s) + override def formatDate(d: Date): String = fdf.format(d) +} + +class LegacySimpleDateFormatter(pattern: String, locale: Locale) extends LegacyDateFormatter { + @transient + private lazy val sdf = new SimpleDateFormat(pattern, locale) + override def parseToDate(s: String): Date = sdf.parse(s) + override def formatDate(d: Date): String = sdf.format(d) +} + object DateFormatter { + import LegacyDateFormats._ + val defaultLocale: Locale = Locale.US - def apply(format: String, zoneId: ZoneId, locale: Locale): DateFormatter = { + def defaultPattern(): String = { + if (SQLConf.get.legacyTimeParserEnabled) "yyyy-MM-dd" else "uuuu-MM-dd" + } + + private def getFormatter( + format: Option[String], + zoneId: ZoneId, + locale: Locale = defaultLocale, + legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT): DateFormatter = { + + val pattern = format.getOrElse(defaultPattern) if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyDateFormatter(format, locale) + legacyFormat match { + case FAST_DATE_FORMAT => + new LegacyFastDateFormatter(pattern, locale) + case SIMPLE_DATE_FORMAT | LENIENT_SIMPLE_DATE_FORMAT => + new LegacySimpleDateFormatter(pattern, locale) + } } else { - new Iso8601DateFormatter(format, zoneId, locale) + new Iso8601DateFormatter(pattern, zoneId, locale) } } + def apply( + format: String, + zoneId: ZoneId, + locale: Locale, + legacyFormat: LegacyDateFormat): DateFormatter = { + getFormatter(Some(format), zoneId, locale, legacyFormat) + } + def apply(format: String, zoneId: ZoneId): DateFormatter = { - apply(format, zoneId, defaultLocale) + getFormatter(Some(format), zoneId) } def apply(zoneId: ZoneId): DateFormatter = { - if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyDateFormatter("yyyy-MM-dd", defaultLocale) - } else { - new Iso8601DateFormatter("uuuu-MM-dd", zoneId, defaultLocale) - } + getFormatter(None, zoneId) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index fe1a4fe710c20..4893a7ec91cbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -17,19 +17,20 @@ package org.apache.spark.sql.catalyst.util -import java.text.ParseException +import java.text.{ParseException, ParsePosition, SimpleDateFormat} import java.time._ import java.time.format.DateTimeParseException import java.time.temporal.ChronoField.MICRO_OF_SECOND import java.time.temporal.TemporalQueries -import java.util.{Locale, TimeZone} +import java.util.{Calendar, GregorianCalendar, Locale, TimeZone} import java.util.concurrent.TimeUnit.SECONDS import 
org.apache.commons.lang3.time.FastDateFormat -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS -import org.apache.spark.sql.catalyst.util.DateTimeUtils.convertSpecialTimestamp +import org.apache.spark.sql.catalyst.util.DateTimeConstants._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{ convertSpecialTimestamp, SQLTimestamp} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.Decimal sealed trait TimestampFormatter extends Serializable { /** @@ -90,44 +91,139 @@ class FractionTimestampFormatter(zoneId: ZoneId) override protected lazy val formatter = DateTimeFormatterHelper.fractionFormatter } -class LegacyTimestampFormatter( +/** + * The custom sub-class of `GregorianCalendar` is needed to get access to + * protected `fields` immediately after parsing. We cannot use + * the `get()` method because it performs normalization of the fraction + * part. Accordingly, the `MILLISECOND` field doesn't contain original value. + * + * Also this class allows to set raw value to the `MILLISECOND` field + * directly before formatting. + */ +class MicrosCalendar(tz: TimeZone, digitsInFraction: Int) + extends GregorianCalendar(tz, Locale.US) { + // Converts parsed `MILLISECOND` field to seconds fraction in microsecond precision. + // For example if the fraction pattern is `SSSS` then `digitsInFraction` = 4, and + // if the `MILLISECOND` field was parsed to `1234`. + def getMicros(): SQLTimestamp = { + // Append 6 zeros to the field: 1234 -> 1234000000 + val d = fields(Calendar.MILLISECOND) * MICROS_PER_SECOND + // Take the first 6 digits from `d`: 1234000000 -> 123400 + // The rest contains exactly `digitsInFraction`: `0000` = 10 ^ digitsInFraction + // So, the result is `(1234 * 1000000) / (10 ^ digitsInFraction) + d / Decimal.POW_10(digitsInFraction) + } + + // Converts the seconds fraction in microsecond precision to a value + // that can be correctly formatted according to the specified fraction pattern. + // The method performs operations opposite to `getMicros()`. 
+ def setMicros(micros: Long): Unit = { + val d = micros * Decimal.POW_10(digitsInFraction) + fields(Calendar.MILLISECOND) = (d / MICROS_PER_SECOND).toInt + } +} + +class LegacyFastTimestampFormatter( pattern: String, zoneId: ZoneId, locale: Locale) extends TimestampFormatter { - @transient private lazy val format = + @transient private lazy val fastDateFormat = FastDateFormat.getInstance(pattern, TimeZone.getTimeZone(zoneId), locale) + @transient private lazy val cal = new MicrosCalendar( + fastDateFormat.getTimeZone, + fastDateFormat.getPattern.count(_ == 'S')) + + def parse(s: String): SQLTimestamp = { + cal.clear() // Clear the calendar because it can be re-used many times + if (!fastDateFormat.parse(s, new ParsePosition(0), cal)) { + throw new IllegalArgumentException(s"'$s' is an invalid timestamp") + } + val micros = cal.getMicros() + cal.set(Calendar.MILLISECOND, 0) + cal.getTimeInMillis * MICROS_PER_MILLIS + micros + } + + def format(timestamp: SQLTimestamp): String = { + cal.setTimeInMillis(Math.floorDiv(timestamp, MICROS_PER_SECOND) * MILLIS_PER_SECOND) + cal.setMicros(Math.floorMod(timestamp, MICROS_PER_SECOND)) + fastDateFormat.format(cal) + } +} - protected def toMillis(s: String): Long = format.parse(s).getTime +class LegacySimpleTimestampFormatter( + pattern: String, + zoneId: ZoneId, + locale: Locale, + lenient: Boolean = true) extends TimestampFormatter { + @transient private lazy val sdf = { + val formatter = new SimpleDateFormat(pattern, locale) + formatter.setTimeZone(TimeZone.getTimeZone(zoneId)) + formatter.setLenient(lenient) + formatter + } - override def parse(s: String): Long = toMillis(s) * MICROS_PER_MILLIS + override def parse(s: String): Long = { + sdf.parse(s).getTime * MICROS_PER_MILLIS + } override def format(us: Long): String = { - format.format(DateTimeUtils.toJavaTimestamp(us)) + val timestamp = DateTimeUtils.toJavaTimestamp(us) + sdf.format(timestamp) } } +object LegacyDateFormats extends Enumeration { + type LegacyDateFormat = Value + val FAST_DATE_FORMAT, SIMPLE_DATE_FORMAT, LENIENT_SIMPLE_DATE_FORMAT = Value +} + object TimestampFormatter { + import LegacyDateFormats._ + val defaultLocale: Locale = Locale.US - def apply(format: String, zoneId: ZoneId, locale: Locale): TimestampFormatter = { + def defaultPattern(): String = s"${DateFormatter.defaultPattern()} HH:mm:ss" + + private def getFormatter( + format: Option[String], + zoneId: ZoneId, + locale: Locale = defaultLocale, + legacyFormat: LegacyDateFormat = LENIENT_SIMPLE_DATE_FORMAT): TimestampFormatter = { + + val pattern = format.getOrElse(defaultPattern) if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyTimestampFormatter(format, zoneId, locale) + legacyFormat match { + case FAST_DATE_FORMAT => + new LegacyFastTimestampFormatter(pattern, zoneId, locale) + case SIMPLE_DATE_FORMAT => + new LegacySimpleTimestampFormatter(pattern, zoneId, locale, lenient = false) + case LENIENT_SIMPLE_DATE_FORMAT => + new LegacySimpleTimestampFormatter(pattern, zoneId, locale, lenient = true) + } } else { - new Iso8601TimestampFormatter(format, zoneId, locale) + new Iso8601TimestampFormatter(pattern, zoneId, locale) } } + def apply( + format: String, + zoneId: ZoneId, + locale: Locale, + legacyFormat: LegacyDateFormat): TimestampFormatter = { + getFormatter(Some(format), zoneId, locale, legacyFormat) + } + + def apply(format: String, zoneId: ZoneId, legacyFormat: LegacyDateFormat): TimestampFormatter = { + getFormatter(Some(format), zoneId, defaultLocale, legacyFormat) + } + def apply(format: String, zoneId: 
ZoneId): TimestampFormatter = { - apply(format, zoneId, defaultLocale) + getFormatter(Some(format), zoneId) } def apply(zoneId: ZoneId): TimestampFormatter = { - if (SQLConf.get.legacyTimeParserEnabled) { - new LegacyTimestampFormatter("yyyy-MM-dd HH:mm:ss", zoneId, defaultLocale) - } else { - new Iso8601TimestampFormatter("uuuu-MM-dd HH:mm:ss", zoneId, defaultLocale) - } + getFormatter(None, zoneId) } def getFractionFormatter(zoneId: ZoneId): TimestampFormatter = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9ce64b09f7870..f32e48e1cc128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -541,7 +541,7 @@ object Decimal { /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 - private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) private val BIG_DEC_ZERO = BigDecimal(0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 274d0beebd300..f04149ab7eb29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils, Timesta import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -241,41 +242,45 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("DateFormat") { - checkEvaluation( - DateFormatClass(Literal.create(null, TimestampType), Literal("y"), gmtId), - null) - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), - Literal.create(null, StringType), gmtId), null) - - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), - Literal("y"), gmtId), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), gmtId), "2013") - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), - Literal("H"), gmtId), "0") - checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), gmtId), "13") - - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), - Literal("y"), pstId), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), pstId), "2013") - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), - Literal("H"), pstId), "0") - checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), pstId), "5") - - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), - Literal("y"), jstId), "2015") - checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), jstId), "2013") - checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), - Literal("H"), jstId), "0") - checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), jstId), "22") - - // SPARK-28072 The codegen path should 
work - checkEvaluation( - expression = DateFormatClass( - BoundReference(ordinal = 0, dataType = TimestampType, nullable = true), - BoundReference(ordinal = 1, dataType = StringType, nullable = true), - jstId), - expected = "22", - inputRow = InternalRow(DateTimeUtils.fromJavaTimestamp(ts), UTF8String.fromString("H"))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + checkEvaluation( + DateFormatClass(Literal.create(null, TimestampType), Literal("y"), gmtId), + null) + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal.create(null, StringType), gmtId), null) + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal("y"), gmtId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), gmtId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, gmtId), + Literal("H"), gmtId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), gmtId), "13") + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), + Literal("y"), pstId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), pstId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, pstId), + Literal("H"), pstId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), pstId), "5") + + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), + Literal("y"), jstId), "2015") + checkEvaluation(DateFormatClass(Literal(ts), Literal("y"), jstId), "2013") + checkEvaluation(DateFormatClass(Cast(Literal(d), TimestampType, jstId), + Literal("H"), jstId), "0") + checkEvaluation(DateFormatClass(Literal(ts), Literal("H"), jstId), "22") + + // SPARK-28072 The codegen path should work + checkEvaluation( + expression = DateFormatClass( + BoundReference(ordinal = 0, dataType = TimestampType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + jstId), + expected = "22", + inputRow = InternalRow(DateTimeUtils.fromJavaTimestamp(ts), UTF8String.fromString("H"))) + } + } } test("Hour") { @@ -705,162 +710,189 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("from_unixtime") { - val fmt1 = "yyyy-MM-dd HH:mm:ss" - val sdf1 = new SimpleDateFormat(fmt1, Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { - val timeZoneId = Option(tz.getID) - sdf1.setTimeZone(tz) - sdf2.setTimeZone(tz) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val fmt1 = "yyyy-MM-dd HH:mm:ss" + val sdf1 = new SimpleDateFormat(fmt1, Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) - checkEvaluation( - FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId), - sdf1.format(new Timestamp(0))) - checkEvaluation(FromUnixTime( - Literal(1000L), Literal(fmt1), timeZoneId), - sdf1.format(new Timestamp(1000000))) - checkEvaluation( - FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId), - sdf2.format(new Timestamp(-1000000))) - checkEvaluation( - FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType), timeZoneId), - null) - checkEvaluation( - 
FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId), - null) - checkEvaluation( - FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), - null) - checkEvaluation( - FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId), + sdf1.format(new Timestamp(0))) + checkEvaluation(FromUnixTime( + Literal(1000L), Literal(fmt1), timeZoneId), + sdf1.format(new Timestamp(1000000))) + checkEvaluation( + FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId), + sdf2.format(new Timestamp(-1000000))) + checkEvaluation( + FromUnixTime( + Literal.create(null, LongType), + Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) - // SPARK-28072 The codegen path for non-literal input should also work - checkEvaluation( - expression = FromUnixTime( - BoundReference(ordinal = 0, dataType = LongType, nullable = true), - BoundReference(ordinal = 1, dataType = StringType, nullable = true), - timeZoneId), - expected = UTF8String.fromString(sdf1.format(new Timestamp(0))), - inputRow = InternalRow(0L, UTF8String.fromString(fmt1))) + // SPARK-28072 The codegen path for non-literal input should also work + checkEvaluation( + expression = FromUnixTime( + BoundReference(ordinal = 0, dataType = LongType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + timeZoneId), + expected = UTF8String.fromString(sdf1.format(new Timestamp(0))), + inputRow = InternalRow(0L, UTF8String.fromString(fmt1))) + } + } } } test("unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - sdf3.setTimeZone(TimeZoneGMT) - - withDefaultTimeZone(TimeZoneGMT) { - for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { - val timeZoneId = Option(tz.getID) - sdf1.setTimeZone(tz) - sdf2.setTimeZone(tz) - - val date1 = Date.valueOf("2015-07-24") - checkEvaluation(UnixTimestamp( - Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) - checkEvaluation(UnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - 1000L) - checkEvaluation( - UnixTimestamp( - Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - 1000L) - checkEvaluation( - UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), - -1000L) - checkEvaluation(UnixTimestamp( - Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) - val t1 = UnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - val t2 = UnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] - assert(t2 - t1 <= 1) - 
checkEvaluation( - UnixTimestamp( - Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), - null) - checkEvaluation( - UnixTimestamp(Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), - null) - checkEvaluation( - UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) + sdf3.setTimeZone(TimeZoneGMT) + + withDefaultTimeZone(TimeZoneGMT) { + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + val date1 = Date.valueOf("2015-07-24") + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(0))), + Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) + checkEvaluation(UnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), + Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation( + UnixTimestamp( + Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + 1000L) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + MILLISECONDS.toSeconds( + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), + Literal(fmt2), timeZoneId), + -1000L) + checkEvaluation(UnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), + MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + val t1 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + val t2 = UnixTimestamp( + CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation( + UnixTimestamp( + Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), + null) + checkEvaluation( + UnixTimestamp( + Literal.create(null, DateType), + Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + null) + checkEvaluation( + UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), + MILLISECONDS.toSeconds( + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + } + } } } } test("to_unix_timestamp") { - val fmt1 = "yyyy-MM-dd HH:mm:ss" - val sdf1 = new SimpleDateFormat(fmt1, Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - sdf3.setTimeZone(TimeZoneGMT) - - withDefaultTimeZone(TimeZoneGMT) { - for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { - val timeZoneId = Option(tz.getID) - sdf1.setTimeZone(tz) - sdf2.setTimeZone(tz) - - val date1 = Date.valueOf("2015-07-24") - checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L) - 
checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId), - 1000L) - checkEvaluation(ToUnixTimestamp( - Literal(new Timestamp(1000000)), Literal(fmt1)), - 1000L) - checkEvaluation( - ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), - -1000L) - checkEvaluation(ToUnixTimestamp( - Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) - val t1 = ToUnixTimestamp( - CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] - val t2 = ToUnixTimestamp( - CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] - assert(t2 - t1 <= 1) - checkEvaluation(ToUnixTimestamp( - Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null) - checkEvaluation( - ToUnixTimestamp( - Literal.create(null, DateType), Literal(fmt1), timeZoneId), - null) - checkEvaluation(ToUnixTimestamp( - Literal(date1), Literal.create(null, StringType), timeZoneId), - MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) - checkEvaluation( - ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val fmt1 = "yyyy-MM-dd HH:mm:ss" + val sdf1 = new SimpleDateFormat(fmt1, Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + val fmt3 = "yy-MM-dd" + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) + sdf3.setTimeZone(TimeZoneGMT) + + withDefaultTimeZone(TimeZoneGMT) { + for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { + val timeZoneId = Option(tz.getID) + sdf1.setTimeZone(tz) + sdf2.setTimeZone(tz) + + val date1 = Date.valueOf("2015-07-24") + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId), + 1000L) + checkEvaluation(ToUnixTimestamp( + Literal(new Timestamp(1000000)), Literal(fmt1)), + 1000L) + checkEvaluation( + ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), + MILLISECONDS.toSeconds( + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + ToUnixTimestamp( + Literal(sdf2.format(new Timestamp(-1000000))), + Literal(fmt2), timeZoneId), + -1000L) + checkEvaluation(ToUnixTimestamp( + Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), + MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + val t1 = ToUnixTimestamp( + CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] + val t2 = ToUnixTimestamp( + CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] + assert(t2 - t1 <= 1) + checkEvaluation(ToUnixTimestamp( + Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null) + checkEvaluation( + ToUnixTimestamp( + Literal.create(null, DateType), Literal(fmt1), timeZoneId), + null) + checkEvaluation(ToUnixTimestamp( + Literal(date1), Literal.create(null, StringType), timeZoneId), + MILLISECONDS.toSeconds( + 
DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + checkEvaluation( + ToUnixTimestamp( + Literal("2015-07-24"), + Literal("not a valid format"), timeZoneId), null) - // SPARK-28072 The codegen path for non-literal input should also work - checkEvaluation( - expression = ToUnixTimestamp( - BoundReference(ordinal = 0, dataType = StringType, nullable = true), - BoundReference(ordinal = 1, dataType = StringType, nullable = true), - timeZoneId), - expected = 0L, - inputRow = InternalRow( - UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1))) + // SPARK-28072 The codegen path for non-literal input should also work + checkEvaluation( + expression = ToUnixTimestamp( + BoundReference(ordinal = 0, dataType = StringType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + timeZoneId), + expected = 0L, + inputRow = InternalRow( + UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1))) + } + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index d125581857e0b..2d5504ac00ffa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} +import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf @@ -2881,7 +2882,7 @@ object functions { * @since 1.5.0 */ def from_unixtime(ut: Column): Column = withExpr { - FromUnixTime(ut.expr, Literal("uuuu-MM-dd HH:mm:ss")) + FromUnixTime(ut.expr, Literal(TimestampFormatter.defaultPattern)) } /** @@ -2913,7 +2914,7 @@ object functions { * @since 1.5.0 */ def unix_timestamp(): Column = withExpr { - UnixTimestamp(CurrentTimestamp(), Literal("uuuu-MM-dd HH:mm:ss")) + UnixTimestamp(CurrentTimestamp(), Literal(TimestampFormatter.defaultPattern)) } /** @@ -2927,7 +2928,7 @@ object functions { * @since 1.5.0 */ def unix_timestamp(s: Column): Column = withExpr { - UnixTimestamp(s.expr, Literal("uuuu-MM-dd HH:mm:ss")) + UnixTimestamp(s.expr, Literal(TimestampFormatter.defaultPattern)) } /** diff --git a/sql/core/src/test/resources/test-data/bad_after_good.csv b/sql/core/src/test/resources/test-data/bad_after_good.csv index 4621a7d23714d..1a7c2651a11a7 100644 --- a/sql/core/src/test/resources/test-data/bad_after_good.csv +++ b/sql/core/src/test/resources/test-data/bad_after_good.csv @@ -1,2 +1,2 @@ "good record",1999-08-01 -"bad record",1999-088-01 +"bad record",1999-088_01 diff --git a/sql/core/src/test/resources/test-data/value-malformed.csv b/sql/core/src/test/resources/test-data/value-malformed.csv index 8945ed73d2e83..6e6f08fca6df8 100644 --- a/sql/core/src/test/resources/test-data/value-malformed.csv +++ b/sql/core/src/test/resources/test-data/value-malformed.csv @@ -1,2 +1,2 @@ -0,2013-111-11 12:13:14 +0,2013-111_11 12:13:14 1,1983-08-04 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala 
index bb8cdf3cb6de1..41d53c959ef99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -96,15 +96,19 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { } test("date format") { - val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val df = Seq((d, sdf.format(d), ts)).toDF("a", "b", "c") - checkAnswer( - df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), - Row("2015", "2015", "2013")) + checkAnswer( + df.select(date_format($"a", "y"), date_format($"b", "y"), date_format($"c", "y")), + Row("2015", "2015", "2013")) - checkAnswer( - df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), - Row("2015", "2015", "2013")) + checkAnswer( + df.selectExpr("date_format(a, 'y')", "date_format(b, 'y')", "date_format(c, 'y')"), + Row("2015", "2015", "2013")) + } + } } test("year") { @@ -525,170 +529,194 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) - val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" - val sdf2 = new SimpleDateFormat(fmt2, Locale.US) - val fmt3 = "yy-MM-dd HH-mm-ss" - val sdf3 = new SimpleDateFormat(fmt3, Locale.US) - val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") - checkAnswer( - df.select(from_unixtime(col("a"))), - Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) - checkAnswer( - df.select(from_unixtime(col("a"), fmt2)), - Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) - checkAnswer( - df.select(from_unixtime(col("a"), fmt3)), - Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr("from_unixtime(a)"), - Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr(s"from_unixtime(a, '$fmt2')"), - Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) - checkAnswer( - df.selectExpr(s"from_unixtime(a, '$fmt3')"), - Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" + val sdf2 = new SimpleDateFormat(fmt2, Locale.US) + val fmt3 = "yy-MM-dd HH-mm-ss" + val sdf3 = new SimpleDateFormat(fmt3, Locale.US) + val df = Seq((1000, "yyyy-MM-dd HH:mm:ss.SSS"), (-1000, "yy-MM-dd HH-mm-ss")).toDF("a", "b") + checkAnswer( + df.select(from_unixtime(col("a"))), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt2)), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.select(from_unixtime(col("a"), fmt3)), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr("from_unixtime(a)"), + Seq(Row(sdf1.format(new Timestamp(1000000))), Row(sdf1.format(new Timestamp(-1000000))))) + checkAnswer( + 
df.selectExpr(s"from_unixtime(a, '$fmt2')"), + Seq(Row(sdf2.format(new Timestamp(1000000))), Row(sdf2.format(new Timestamp(-1000000))))) + checkAnswer( + df.selectExpr(s"from_unixtime(a, '$fmt3')"), + Seq(Row(sdf3.format(new Timestamp(1000000))), Row(sdf3.format(new Timestamp(-1000000))))) + } + } } private def secs(millis: Long): Long = TimeUnit.MILLISECONDS.toSeconds(millis) test("unix_timestamp") { - val date1 = Date.valueOf("2015-07-24") - val date2 = Date.valueOf("2015-07-25") - val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") - val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") - val s1 = "2015/07/24 10:00:00.5" - val s2 = "2015/07/25 02:02:02.6" - val ss1 = "2015-07-24 10:00:00" - val ss2 = "2015-07-25 02:02:02" - val fmt = "yyyy/MM/dd HH:mm:ss.S" - val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") - checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( - Row(secs(date1.getTime)), Row(secs(date2.getTime)))) - checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( - Row(secs(date1.getTime)), Row(secs(date2.getTime)))) - checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - - val x1 = "2015-07-24 10:00:00" - val x2 = "2015-25-07 02:02:02" - val x3 = "2015-07-24 25:02:02" - val x4 = "2015-24-07 26:02:02" - val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") - val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") - - val df1 = Seq(x1, x2, x3, x4).toDF("x") - checkAnswer(df1.select(unix_timestamp(col("x"))), Seq( - Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) - checkAnswer(df1.selectExpr("unix_timestamp(x)"), Seq( - Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) - checkAnswer(df1.select(unix_timestamp(col("x"), "yyyy-dd-MM HH:mm:ss")), Seq( - Row(null), Row(secs(ts2.getTime)), Row(null), Row(null))) - checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( - Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) - - // invalid format - checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd aa:HH:ss')"), Seq( - Row(null), Row(null), Row(null), Row(null))) - - // february - val y1 = "2016-02-29" - val y2 = "2017-02-29" - val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") - val df2 = Seq(y1, y2).toDF("y") - checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( - Row(secs(ts5.getTime)), Row(null))) - - val now = sql("select unix_timestamp()").collect().head.getLong(0) - checkAnswer( - sql(s"select cast ($now as timestamp)"), - Row(new java.util.Date(TimeUnit.SECONDS.toMillis(now)))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val 
ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.select(unix_timestamp(col("ts"))), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.select(unix_timestamp(col("ss"))), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.select(unix_timestamp(col("d"), fmt)), Seq( + Row(secs(date1.getTime)), Row(secs(date2.getTime)))) + checkAnswer(df.select(unix_timestamp(col("s"), fmt)), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr("unix_timestamp(ts)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr("unix_timestamp(ss)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr(s"unix_timestamp(d, '$fmt')"), Seq( + Row(secs(date1.getTime)), Row(secs(date2.getTime)))) + checkAnswer(df.selectExpr(s"unix_timestamp(s, '$fmt')"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + + val x1 = "2015-07-24 10:00:00" + val x2 = "2015-25-07 02:02:02" + val x3 = "2015-07-24 25:02:02" + val x4 = "2015-24-07 26:02:02" + val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") + val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") + + val df1 = Seq(x1, x2, x3, x4).toDF("x") + checkAnswer(df1.select(unix_timestamp(col("x"))), Seq( + Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) + checkAnswer(df1.selectExpr("unix_timestamp(x)"), Seq( + Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) + checkAnswer(df1.select(unix_timestamp(col("x"), "yyyy-dd-MM HH:mm:ss")), Seq( + Row(null), Row(secs(ts2.getTime)), Row(null), Row(null))) + checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( + Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) + + // invalid format + checkAnswer(df1.selectExpr(s"unix_timestamp(x, 'yyyy-MM-dd aa:HH:ss')"), Seq( + Row(null), Row(null), Row(null), Row(null))) + + // february + val y1 = "2016-02-29" + val y2 = "2017-02-29" + val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") + val df2 = Seq(y1, y2).toDF("y") + checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( + Row(secs(ts5.getTime)), Row(null))) + + val now = sql("select unix_timestamp()").collect().head.getLong(0) + checkAnswer( + sql(s"select cast ($now as timestamp)"), + Row(new java.util.Date(TimeUnit.SECONDS.toMillis(now)))) + } + } } test("to_unix_timestamp") { - val date1 = Date.valueOf("2015-07-24") - val date2 = Date.valueOf("2015-07-25") - val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") - val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") - val s1 = "2015/07/24 10:00:00.5" - val s2 = "2015/07/25 02:02:02.6" - val ss1 = "2015-07-24 10:00:00" - val ss2 = "2015-07-25 02:02:02" - val fmt = "yyyy/MM/dd HH:mm:ss.S" - val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") - checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( - Row(secs(date1.getTime)), Row(secs(date2.getTime)))) - checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( - Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) - - val x1 = "2015-07-24 10:00:00" - val x2 = "2015-25-07 02:02:02" - val x3 = "2015-07-24 25:02:02" - val x4 = 
"2015-24-07 26:02:02" - val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") - val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") - - val df1 = Seq(x1, x2, x3, x4).toDF("x") - checkAnswer(df1.selectExpr("to_unix_timestamp(x)"), Seq( - Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) - checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( - Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) - - // february - val y1 = "2016-02-29" - val y2 = "2017-02-29" - val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") - val df2 = Seq(y1, y2).toDF("y") - checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( - Row(secs(ts5.getTime)), Row(null))) - - // invalid format - checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')"), Seq( - Row(null), Row(null), Row(null), Row(null))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00.3") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02.2") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + checkAnswer(df.selectExpr("to_unix_timestamp(ts)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr("to_unix_timestamp(ss)"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(d, '$fmt')"), Seq( + Row(secs(date1.getTime)), Row(secs(date2.getTime)))) + checkAnswer(df.selectExpr(s"to_unix_timestamp(s, '$fmt')"), Seq( + Row(secs(ts1.getTime)), Row(secs(ts2.getTime)))) + + val x1 = "2015-07-24 10:00:00" + val x2 = "2015-25-07 02:02:02" + val x3 = "2015-07-24 25:02:02" + val x4 = "2015-24-07 26:02:02" + val ts3 = Timestamp.valueOf("2015-07-24 02:25:02") + val ts4 = Timestamp.valueOf("2015-07-24 00:10:00") + + val df1 = Seq(x1, x2, x3, x4).toDF("x") + checkAnswer(df1.selectExpr("to_unix_timestamp(x)"), Seq( + Row(secs(ts1.getTime)), Row(null), Row(null), Row(null))) + checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd mm:HH:ss')"), Seq( + Row(secs(ts4.getTime)), Row(null), Row(secs(ts3.getTime)), Row(null))) + + // february + val y1 = "2016-02-29" + val y2 = "2017-02-29" + val ts5 = Timestamp.valueOf("2016-02-29 00:00:00") + val df2 = Seq(y1, y2).toDF("y") + checkAnswer(df2.select(unix_timestamp(col("y"), "yyyy-MM-dd")), Seq( + Row(secs(ts5.getTime)), Row(null))) + + // invalid format + checkAnswer(df1.selectExpr(s"to_unix_timestamp(x, 'yyyy-MM-dd bb:HH:ss')"), Seq( + Row(null), Row(null), Row(null), Row(null))) + } + } } test("to_timestamp") { - val date1 = Date.valueOf("2015-07-24") - val date2 = Date.valueOf("2015-07-25") - val ts_date1 = Timestamp.valueOf("2015-07-24 00:00:00") - val ts_date2 = Timestamp.valueOf("2015-07-25 00:00:00") - val ts1 = Timestamp.valueOf("2015-07-24 10:00:00") - val ts2 = Timestamp.valueOf("2015-07-25 02:02:02") - val s1 = "2015/07/24 10:00:00.5" - val s2 = "2015/07/25 02:02:02.6" - val ts1m = Timestamp.valueOf("2015-07-24 10:00:00.5") - val ts2m = Timestamp.valueOf("2015-07-25 02:02:02.6") - val ss1 = "2015-07-24 10:00:00" - val ss2 = "2015-07-25 02:02:02" - val fmt = "yyyy/MM/dd HH:mm:ss.S" - val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, 
ss2)).toDF("d", "ts", "s", "ss") - - checkAnswer(df.select(to_timestamp(col("ss"))), - df.select(unix_timestamp(col("ss")).cast("timestamp"))) - checkAnswer(df.select(to_timestamp(col("ss"))), Seq( - Row(ts1), Row(ts2))) - checkAnswer(df.select(to_timestamp(col("s"), fmt)), Seq( - Row(ts1m), Row(ts2m))) - checkAnswer(df.select(to_timestamp(col("ts"), fmt)), Seq( - Row(ts1), Row(ts2))) - checkAnswer(df.select(to_timestamp(col("d"), "yyyy-MM-dd")), Seq( - Row(ts_date1), Row(ts_date2))) + Seq(false, true).foreach { legacyParser => + withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) { + val date1 = Date.valueOf("2015-07-24") + val date2 = Date.valueOf("2015-07-25") + val ts_date1 = Timestamp.valueOf("2015-07-24 00:00:00") + val ts_date2 = Timestamp.valueOf("2015-07-25 00:00:00") + val ts1 = Timestamp.valueOf("2015-07-24 10:00:00") + val ts2 = Timestamp.valueOf("2015-07-25 02:02:02") + val s1 = "2015/07/24 10:00:00.5" + val s2 = "2015/07/25 02:02:02.6" + val ts1m = Timestamp.valueOf("2015-07-24 10:00:00.5") + val ts2m = Timestamp.valueOf("2015-07-25 02:02:02.6") + val ss1 = "2015-07-24 10:00:00" + val ss2 = "2015-07-25 02:02:02" + val fmt = "yyyy/MM/dd HH:mm:ss.S" + val df = Seq((date1, ts1, s1, ss1), (date2, ts2, s2, ss2)).toDF("d", "ts", "s", "ss") + + checkAnswer(df.select(to_timestamp(col("ss"))), + df.select(unix_timestamp(col("ss")).cast("timestamp"))) + checkAnswer(df.select(to_timestamp(col("ss"))), Seq( + Row(ts1), Row(ts2))) + if (legacyParser) { + // In Spark 2.4 and earlier, to_timestamp() parses in seconds precision and cuts off + // the fractional part of seconds. The behavior was changed by SPARK-27438. + val legacyFmt = "yyyy/MM/dd HH:mm:ss" + checkAnswer(df.select(to_timestamp(col("s"), legacyFmt)), Seq( + Row(ts1), Row(ts2))) + } else { + checkAnswer(df.select(to_timestamp(col("s"), fmt)), Seq( + Row(ts1m), Row(ts2m))) + } + checkAnswer(df.select(to_timestamp(col("ts"), fmt)), Seq( + Row(ts1), Row(ts2))) + checkAnswer(df.select(to_timestamp(col("d"), "yyyy-MM-dd")), Seq( + Row(ts_date1), Row(ts_date2))) + } + } } test("datediff") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 97dfbbdb7fd2f..b1105b4a63bba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1182,7 +1182,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .schema(schemaWithCorrField1) .csv(testFile(valueMalformedFile)) checkAnswer(df2, - Row(0, null, "0,2013-111-11 12:13:14") :: + Row(0, null, "0,2013-111_11 12:13:14") :: Row(1, java.sql.Date.valueOf("1983-08-04"), null) :: Nil) @@ -1199,7 +1199,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .schema(schemaWithCorrField2) .csv(testFile(valueMalformedFile)) checkAnswer(df3, - Row(0, "0,2013-111-11 12:13:14", null) :: + Row(0, "0,2013-111_11 12:13:14", null) :: Row(1, null, java.sql.Date.valueOf("1983-08-04")) :: Nil) @@ -1435,7 +1435,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa assert(df.filter($"_corrupt_record".isNull).count() == 1) checkAnswer( df.select(columnNameOfCorruptRecord), - Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil + Row("0,2013-111_11 12:13:14") :: Row(null) :: Nil ) } @@ -2093,7 +2093,7 @@ abstract class CSVSuite 
extends QueryTest with SharedSparkSession with TestCsvDa Seq("csv", "").foreach { reader => withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> reader) { withTempPath { path => - val df = Seq(("0", "2013-111-11")).toDF("a", "b") + val df = Seq(("0", "2013-111_11")).toDF("a", "b") df.write .option("header", "true") .csv(path.getAbsolutePath) @@ -2109,7 +2109,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord) .schema(schemaWithCorrField) .csv(path.getAbsoluteFile.toString) - checkAnswer(readDF, Row(0, null, "0,2013-111-11") :: Nil) + checkAnswer(readDF, Row(0, null, "0,2013-111_11") :: Nil) } } } @@ -2216,7 +2216,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa val readback = spark.read .option("mode", mode) .option("header", true) - .option("timestampFormat", "uuuu-MM-dd HH:mm:ss") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") .option("multiLine", multiLine) .schema("c0 string, c1 integer, c2 timestamp") .csv(path.getAbsolutePath) @@ -2235,7 +2235,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa } test("filters push down - malformed input in PERMISSIVE mode") { - val invalidTs = "2019-123-14 20:35:30" + val invalidTs = "2019-123_14 20:35:30" val invalidRow = s"0,$invalidTs,999" val validTs = "2019-12-14 20:35:30" Seq(true, false).foreach { filterPushdown => @@ -2252,7 +2252,7 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa .option("mode", "PERMISSIVE") .option("columnNameOfCorruptRecord", "c3") .option("header", true) - .option("timestampFormat", "uuuu-MM-dd HH:mm:ss") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss") .schema("c0 integer, c1 timestamp, c2 integer, c3 string") .csv(path.getAbsolutePath) .where(condition) @@ -2309,3 +2309,10 @@ class CSVv2Suite extends CSVSuite { .sparkConf .set(SQLConf.USE_V1_SOURCE_LIST, "") } + +class CSVLegacyTimeParserSuite extends CSVSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.LEGACY_TIME_PARSER_ENABLED, true) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index b20da2266b0f3..7abe818a29d9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2572,3 +2572,10 @@ class JsonV2Suite extends JsonSuite { .sparkConf .set(SQLConf.USE_V1_SOURCE_LIST, "") } + +class JsonLegacyTimeParserSuite extends JsonSuite { + override protected def sparkConf: SparkConf = + super + .sparkConf + .set(SQLConf.LEGACY_TIME_PARSER_ENABLED, true) +} From 61b1e608f07afd965028313c13bf89c19b006312 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Wed, 12 Feb 2020 23:50:34 +0800 Subject: [PATCH 025/185] [SPARK-30759][SQL][TESTS][FOLLOWUP] Check cache initialization in StringRegexExpression ### What changes were proposed in this pull request? Added new test to `RegexpExpressionsSuite` which checks that `cache` of compiled pattern is set when the `right` expression (pattern in `LIKE`) is a foldable expression. ### Why are the changes needed? To be sure that `cache` in `StringRegexExpression` is initialized for foldable patterns. ### Does this PR introduce any user-facing change? No ### How was this patch tested? 
By running the added test in `RegexpExpressionsSuite`. Closes #27547 from MaxGekk/regexp-cache-test. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/RegexpExpressionsSuite.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 86da62bc74940..712d2bc4c4736 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -329,4 +329,12 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSplit(s1, s2, -1), null, row3) } + test("SPARK-30759: cache initialization for literal patterns") { + val expr = "A" like Literal.create("a", StringType) + expr.eval() + val cache = expr.getClass.getSuperclass + .getDeclaredFields.filter(_.getName.endsWith("cache")).head + cache.setAccessible(true) + assert(cache.get(expr).asInstanceOf[java.util.regex.Pattern].pattern().contains("a")) + } } From 5919bd3b8d3ef3c3e957d8e3e245e00383b979bf Mon Sep 17 00:00:00 2001 From: Eric Wu <492960551@qq.com> Date: Thu, 13 Feb 2020 02:00:23 +0800 Subject: [PATCH 026/185] [SPARK-30651][SQL] Add detailed information for Aggregate operators in EXPLAIN FORMATTED ### What changes were proposed in this pull request? Currently `EXPLAIN FORMATTED` only report input attributes of HashAggregate/ObjectHashAggregate/SortAggregate, while `EXPLAIN EXTENDED` provides more information of Keys, Functions, etc. This PR enhanced `EXPLAIN FORMATTED` to sync with original explain behavior. ### Why are the changes needed? The newly added `EXPLAIN FORMATTED` got less information comparing to the original `EXPLAIN EXTENDED` ### Does this PR introduce any user-facing change? Yes, taking HashAggregate explain result as example. **SQL** ``` EXPLAIN FORMATTED SELECT COUNT(val) + SUM(key) as TOTAL, COUNT(key) FILTER (WHERE val > 1) FROM explain_temp1; ``` **EXPLAIN EXTENDED** ``` == Physical Plan == *(2) HashAggregate(keys=[], functions=[count(val#6), sum(cast(key#5 as bigint)), count(key#5)], output=[TOTAL#62L, count(key) FILTER (WHERE (val > 1))#71L]) +- Exchange SinglePartition, true, [id=#89] +- HashAggregate(keys=[], functions=[partial_count(val#6), partial_sum(cast(key#5 as bigint)), partial_count(key#5) FILTER (WHERE (val#6 > 1))], output=[count#75L, sum#76L, count#77L]) +- *(1) ColumnarToRow +- FileScan parquet default.explain_temp1[key#5,val#6] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex[file:/Users/XXX/spark-dev/spark/spark-warehouse/explain_temp1], PartitionFilters: [], PushedFilters: [], ReadSchema: struct ``` **EXPLAIN FORMATTED - BEFORE** ``` == Physical Plan == * HashAggregate (5) +- Exchange (4) +- HashAggregate (3) +- * ColumnarToRow (2) +- Scan parquet default.explain_temp1 (1) ... ... (5) HashAggregate [codegen id : 2] Input: [count#91L, sum#92L, count#93L] ... ... ``` **EXPLAIN FORMATTED - AFTER** ``` == Physical Plan == * HashAggregate (5) +- Exchange (4) +- HashAggregate (3) +- * ColumnarToRow (2) +- Scan parquet default.explain_temp1 (1) ... ... 
(5) HashAggregate [codegen id : 2] Input: [count#91L, sum#92L, count#93L] Keys: [] Functions: [count(val#6), sum(cast(key#5 as bigint)), count(key#5)] Results: [(count(val#6)#84L + sum(cast(key#5 as bigint))#85L) AS TOTAL#78L, count(key#5)#86L AS count(key) FILTER (WHERE (val > 1))#87L] Output: [TOTAL#78L, count(key) FILTER (WHERE (val > 1))#87L] ... ... ``` ### How was this patch tested? Three tests added in explain.sql for HashAggregate/ObjectHashAggregate/SortAggregate. Closes #27368 from Eric5553/ExplainFormattedAgg. Authored-by: Eric Wu <492960551@qq.com> Signed-off-by: Wenchen Fan --- .../aggregate/BaseAggregateExec.scala | 48 ++++ .../aggregate/HashAggregateExec.scala | 2 +- .../aggregate/ObjectHashAggregateExec.scala | 2 +- .../aggregate/SortAggregateExec.scala | 4 +- .../resources/sql-tests/inputs/explain.sql | 22 +- .../sql-tests/results/explain.sql.out | 232 +++++++++++++++++- 6 files changed, 300 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala new file mode 100644 index 0000000000000..0eaa0f53fdacd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode} + +/** + * Holds common logic for aggregate operators + */ +trait BaseAggregateExec extends UnaryExecNode { + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def resultExpressions: Seq[NamedExpression] + + override def verboseStringWithOperatorId(): String = { + val inputString = child.output.mkString("[", ", ", "]") + val keyString = groupingExpressions.mkString("[", ", ", "]") + val functionString = aggregateExpressions.mkString("[", ", ", "]") + val aggregateAttributeString = aggregateAttributes.mkString("[", ", ", "]") + val resultString = resultExpressions.mkString("[", ", ", "]") + s""" + |(${ExplainUtils.getOpId(this)}) $nodeName ${ExplainUtils.getCodegenId(this)} + |Input: $inputString + |Keys: $keyString + |Functions: $functionString + |Aggregate Attributes: $aggregateAttributeString + |Results: $resultString + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f73e214a6b41f..7a26fd7a8541a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -53,7 +53,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { + extends BaseAggregateExec with BlockingOperatorWithCodegen with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 4376f6b6edd57..3fb58eb2cc8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -67,7 +67,7 @@ case class ObjectHashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index b6e684e62ea5c..77ed469016fa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ 
import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{AliasAwareOutputPartitioning, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -38,7 +38,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with AliasAwareOutputPartitioning { + extends BaseAggregateExec with AliasAwareOutputPartitioning { private[this] val aggregateBufferAttributes = { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) diff --git a/sql/core/src/test/resources/sql-tests/inputs/explain.sql b/sql/core/src/test/resources/sql-tests/inputs/explain.sql index d5253e3daddb0..497b61c6134a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/explain.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/explain.sql @@ -5,6 +5,7 @@ CREATE table explain_temp1 (key int, val int) USING PARQUET; CREATE table explain_temp2 (key int, val int) USING PARQUET; CREATE table explain_temp3 (key int, val int) USING PARQUET; +CREATE table explain_temp4 (key int, val string) USING PARQUET; SET spark.sql.codegen.wholeStage = true; @@ -61,7 +62,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0); @@ -93,6 +94,25 @@ EXPLAIN FORMATTED CREATE VIEW explain_view AS SELECT key, val FROM explain_temp1; +-- HashAggregate +EXPLAIN FORMATTED + SELECT + COUNT(val) + SUM(key) as TOTAL, + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1; + +-- ObjectHashAggregate +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key; + +-- SortAggregate +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP BY key; + -- cleanup DROP TABLE explain_temp1; DROP TABLE explain_temp2; diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 756c14f28a657..bc28d7f87bf00 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 18 +-- Number of queries: 22 -- !query @@ -26,6 +26,14 @@ struct<> +-- !query +CREATE table explain_temp4 (key int, val string) USING PARQUET +-- !query schema +struct<> +-- !query output + + + -- !query SET spark.sql.codegen.wholeStage = true -- !query schema @@ -76,12 +84,20 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x] (8) Exchange Input: [key#x, max(val)#x] @@ -132,12 +148,20 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 2] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x] (8) Filter 
[codegen id : 2] Input : [key#x, max(val)#x, max(val#x)#x] @@ -211,12 +235,20 @@ Input : [key#x, val#x] (10) HashAggregate [codegen id : 3] Input: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] +Aggregate Attributes: [] +Results: [key#x, val#x] (11) Exchange Input: [key#x, val#x] (12) HashAggregate [codegen id : 4] Input: [key#x, val#x] +Keys: [key#x, val#x] +Functions: [] +Aggregate Attributes: [] +Results: [key#x, val#x] -- !query @@ -413,12 +445,20 @@ Input : [key#x, val#x] (9) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (10) Exchange Input: [max#x] (11) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 7 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (18) @@ -450,12 +490,20 @@ Input : [key#x, val#x] (16) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (17) Exchange Input: [max#x] (18) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] -- !query @@ -466,7 +514,7 @@ EXPLAIN FORMATTED FROM explain_temp2 WHERE val > 0) OR - key = (SELECT max(key) + key = (SELECT avg(key) FROM explain_temp3 WHERE val > 0) -- !query schema @@ -489,7 +537,7 @@ Input: [key#x, val#x] (3) Filter [codegen id : 1] Input : [key#x, val#x] -Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (key#x = Subquery scalar-subquery#x, [id=#x])) +Condition : ((key#x = Subquery scalar-subquery#x, [id=#x]) OR (cast(key#x as double) = Subquery scalar-subquery#x, [id=#x])) ===== Subqueries ===== @@ -523,12 +571,20 @@ Input : [key#x, val#x] (8) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_max(key#x)] +Aggregate Attributes: [max#x] +Results: [max#x] (9) Exchange Input: [max#x] (10) HashAggregate [codegen id : 2] Input: [max#x] +Keys: [] +Functions: [max(key#x)] +Aggregate Attributes: [max(key#x)#x] +Results: [max(key#x)#x AS max(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = Subquery scalar-subquery#x, [id=#x] * HashAggregate (17) @@ -560,12 +616,20 @@ Input : [key#x, val#x] (15) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_avg(cast(key#x as bigint))] +Aggregate Attributes: [sum#x, count#xL] +Results: [sum#x, count#xL] (16) Exchange -Input: [max#x] +Input: [sum#x, count#xL] (17) HashAggregate [codegen id : 2] -Input: [max#x] +Input: [sum#x, count#xL] +Keys: [] +Functions: [avg(cast(key#x as bigint))] +Aggregate Attributes: [avg(cast(key#x as bigint))#x] +Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] -- !query @@ -615,12 +679,20 @@ Input: [key#x] (6) HashAggregate [codegen id : 1] Input: [key#x] +Keys: [] +Functions: [partial_avg(cast(key#x as bigint))] +Aggregate Attributes: [sum#x, count#xL] +Results: [sum#x, count#xL] (7) Exchange Input: [sum#x, count#xL] (8) HashAggregate [codegen id : 2] Input: [sum#x, count#xL] +Keys: [] +Functions: [avg(cast(key#x as bigint))] +Aggregate Attributes: [avg(cast(key#x as bigint))#x] +Results: [avg(cast(key#x as bigint))#x AS avg(key)#x] Subquery:2 Hosting operator id = 3 Hosting Expression = ReusedSubquery Subquery scalar-subquery#x, [id=#x] @@ -740,18 +812,30 @@ Input : [key#x, val#x] (5) HashAggregate [codegen id : 1] Input: [key#x, val#x] 
+Keys: [key#x] +Functions: [partial_max(val#x)] +Aggregate Attributes: [max#x] +Results: [key#x, max#x] (6) Exchange Input: [key#x, max#x] (7) HashAggregate [codegen id : 4] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x] (8) ReusedExchange [Reuses operator id: 6] Output : ArrayBuffer(key#x, max#x) (9) HashAggregate [codegen id : 3] Input: [key#x, max#x] +Keys: [key#x] +Functions: [max(val#x)] +Aggregate Attributes: [max(val#x)#x] +Results: [key#x, max(val#x)#x AS max(val)#x] (10) BroadcastExchange Input: [key#x, max(val)#x] @@ -786,6 +870,144 @@ Output: [] (4) Project +-- !query +EXPLAIN FORMATTED + SELECT + COUNT(val) + SUM(key) as TOTAL, + COUNT(key) FILTER (WHERE val > 1) + FROM explain_temp1 +-- !query schema +struct +-- !query output +== Physical Plan == +* HashAggregate (5) ++- Exchange (4) + +- HashAggregate (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp1 (1) + + +(1) Scan parquet default.explain_temp1 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp1] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) HashAggregate +Input: [key#x, val#x] +Keys: [] +Functions: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))] +Aggregate Attributes: [count#xL, sum#xL, count#xL] +Results: [count#xL, sum#xL, count#xL] + +(4) Exchange +Input: [count#xL, sum#xL, count#xL] + +(5) HashAggregate [codegen id : 2] +Input: [count#xL, sum#xL, count#xL] +Keys: [] +Functions: [count(val#x), sum(cast(key#x as bigint)), count(key#x)] +Aggregate Attributes: [count(val#x)#xL, sum(cast(key#x as bigint))#xL, count(key#x)#xL] +Results: [(count(val#x)#xL + sum(cast(key#x as bigint))#xL) AS TOTAL#xL, count(key#x)#xL AS count(key) FILTER (WHERE (val > 1))#xL] + + +-- !query +EXPLAIN FORMATTED + SELECT key, sort_array(collect_set(val))[0] + FROM explain_temp4 + GROUP BY key +-- !query schema +struct +-- !query output +== Physical Plan == +ObjectHashAggregate (5) ++- Exchange (4) + +- ObjectHashAggregate (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp4 (1) + + +(1) Scan parquet default.explain_temp4 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp4] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) ObjectHashAggregate +Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_collect_set(val#x, 0, 0)] +Aggregate Attributes: [buf#x] +Results: [key#x, buf#x] + +(4) Exchange +Input: [key#x, buf#x] + +(5) ObjectHashAggregate +Input: [key#x, buf#x] +Keys: [key#x] +Functions: [collect_set(val#x, 0, 0)] +Aggregate Attributes: [collect_set(val#x, 0, 0)#x] +Results: [key#x, sort_array(collect_set(val#x, 0, 0)#x, true)[0] AS sort_array(collect_set(val), true)[0]#x] + + +-- !query +EXPLAIN FORMATTED + SELECT key, MIN(val) + FROM explain_temp4 + GROUP BY key +-- !query schema +struct +-- !query output +== Physical Plan == +SortAggregate (7) ++- * Sort (6) + +- Exchange (5) + +- SortAggregate (4) + +- * Sort (3) + +- * ColumnarToRow (2) + +- Scan parquet default.explain_temp4 (1) + + +(1) Scan parquet default.explain_temp4 +Output: [key#x, val#x] +Batched: true +Location [not included in comparison]/{warehouse_dir}/explain_temp4] +ReadSchema: struct + +(2) ColumnarToRow [codegen id : 1] +Input: [key#x, val#x] + +(3) Sort [codegen id : 1] +Input: [key#x, 
val#x] + +(4) SortAggregate +Input: [key#x, val#x] +Keys: [key#x] +Functions: [partial_min(val#x)] +Aggregate Attributes: [min#x] +Results: [key#x, min#x] + +(5) Exchange +Input: [key#x, min#x] + +(6) Sort [codegen id : 2] +Input: [key#x, min#x] + +(7) SortAggregate +Input: [key#x, min#x] +Keys: [key#x] +Functions: [min(val#x)] +Aggregate Attributes: [min(val#x)#x] +Results: [key#x, min(val#x)#x AS min(val)#x] + + -- !query DROP TABLE explain_temp1 -- !query schema From aa0d13683cdf9f38f04cc0e73dc8cf63eed29bf4 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Thu, 13 Feb 2020 02:31:48 +0800 Subject: [PATCH 027/185] [SPARK-30760][SQL] Port `millisToDays` and `daysToMillis` on Java 8 time API ### What changes were proposed in this pull request? In the PR, I propose to rewrite the `millisToDays` and `daysToMillis` of `DateTimeUtils` using Java 8 time API. I removed `getOffsetFromLocalMillis` from `DateTimeUtils` because it is a private methods, and is not used anymore in Spark SQL. ### Why are the changes needed? New implementation is based on Proleptic Gregorian calendar which has been already used by other date-time functions. This changes make `millisToDays` and `daysToMillis` consistent to rest Spark SQL API related to date & time operations. ### Does this PR introduce any user-facing change? Yes, this might effect behavior for old dates before 1582 year. ### How was this patch tested? By existing test suites `DateTimeUtilsSuite`, `DateFunctionsSuite`, DateExpressionsSuite`, `SQLQuerySuite` and `HiveResultSuite`. Closes #27494 from MaxGekk/millis-2-days-java8-api. Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 8 +-- .../sql/catalyst/util/DateTimeUtils.scala | 58 +++++-------------- .../catalyst/csv/UnivocityParserSuite.scala | 3 +- .../expressions/DateExpressionsSuite.scala | 19 +++--- .../catalyst/util/DateTimeUtilsSuite.scala | 34 ++++++----- .../spark/sql/execution/HiveResult.scala | 5 ++ .../sql-tests/results/postgreSQL/date.sql.out | 12 ++-- .../apache/spark/sql/SQLQueryTestSuite.scala | 1 + 8 files changed, 62 insertions(+), 78 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1f4c8c041f8bf..cf91489d8e6b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -135,7 +135,7 @@ case class CurrentBatchTimestamp( def toLiteral: Literal = dataType match { case _: TimestampType => Literal(DateTimeUtils.fromJavaTimestamp(new Timestamp(timestampMs)), TimestampType) - case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs, timeZone), DateType) + case _: DateType => Literal(DateTimeUtils.millisToDays(timestampMs, zoneId), DateType) } } @@ -1332,14 +1332,14 @@ case class MonthsBetween( override def nullSafeEval(t1: Any, t2: Any, roundOff: Any): Any = { DateTimeUtils.monthsBetween( - t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], timeZone) + t1.asInstanceOf[Long], t2.asInstanceOf[Long], roundOff.asInstanceOf[Boolean], zoneId) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val dtu = 
DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (d1, d2, roundOff) => { - s"""$dtu.monthsBetween($d1, $d2, $roundOff, $tz)""" + s"""$dtu.monthsBetween($d1, $d2, $roundOff, $zid)""" }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index eeae0674166bc..5976bcbb52fd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -67,24 +67,22 @@ object DateTimeUtils { // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisUtc: Long): SQLDate = { - millisToDays(millisUtc, defaultTimeZone()) + millisToDays(millisUtc, defaultTimeZone().toZoneId) } - def millisToDays(millisUtc: Long, timeZone: TimeZone): SQLDate = { - // SPARK-6785: use Math.floorDiv so negative number of days (dates before 1970) - // will correctly work as input for function toJavaDate(Int) - val millisLocal = millisUtc + timeZone.getOffset(millisUtc) - Math.floorDiv(millisLocal, MILLIS_PER_DAY).toInt + def millisToDays(millisUtc: Long, zoneId: ZoneId): SQLDate = { + val instant = microsToInstant(Math.multiplyExact(millisUtc, MICROS_PER_MILLIS)) + localDateToDays(LocalDateTime.ofInstant(instant, zoneId).toLocalDate) } // reverse of millisToDays def daysToMillis(days: SQLDate): Long = { - daysToMillis(days, defaultTimeZone()) + daysToMillis(days, defaultTimeZone().toZoneId) } - def daysToMillis(days: SQLDate, timeZone: TimeZone): Long = { - val millisLocal = days.toLong * MILLIS_PER_DAY - millisLocal - getOffsetFromLocalMillis(millisLocal, timeZone) + def daysToMillis(days: SQLDate, zoneId: ZoneId): Long = { + val instant = daysToLocalDate(days).atStartOfDay(zoneId).toInstant + instantToMicros(instant) / MICROS_PER_MILLIS } // Converts Timestamp to string according to Hive TimestampWritable convention. 
@@ -589,11 +587,11 @@ object DateTimeUtils { time1: SQLTimestamp, time2: SQLTimestamp, roundOff: Boolean, - timeZone: TimeZone): Double = { + zoneId: ZoneId): Double = { val millis1 = MICROSECONDS.toMillis(time1) val millis2 = MICROSECONDS.toMillis(time2) - val date1 = millisToDays(millis1, timeZone) - val date2 = millisToDays(millis2, timeZone) + val date1 = millisToDays(millis1, zoneId) + val date2 = millisToDays(millis2, zoneId) val (year1, monthInYear1, dayInMonth1, daysToMonthEnd1) = splitDate(date1) val (year2, monthInYear2, dayInMonth2, daysToMonthEnd2) = splitDate(date2) @@ -607,8 +605,8 @@ object DateTimeUtils { } // using milliseconds can cause precision loss with more than 8 digits // we follow Hive's implementation which uses seconds - val secondsInDay1 = MILLISECONDS.toSeconds(millis1 - daysToMillis(date1, timeZone)) - val secondsInDay2 = MILLISECONDS.toSeconds(millis2 - daysToMillis(date2, timeZone)) + val secondsInDay1 = MILLISECONDS.toSeconds(millis1 - daysToMillis(date1, zoneId)) + val secondsInDay2 = MILLISECONDS.toSeconds(millis2 - daysToMillis(date2, zoneId)) val secondsDiff = (dayInMonth1 - dayInMonth2) * SECONDS_PER_DAY + secondsInDay1 - secondsInDay2 val secondsInMonth = DAYS.toSeconds(31) val diff = monthDiff + secondsDiff / secondsInMonth.toDouble @@ -737,8 +735,8 @@ object DateTimeUtils { millis += offset millis - millis % MILLIS_PER_DAY - offset case _ => // Try to truncate date levels - val dDays = millisToDays(millis, timeZone) - daysToMillis(truncDate(dDays, level), timeZone) + val dDays = millisToDays(millis, timeZone.toZoneId) + daysToMillis(truncDate(dDays, level), timeZone.toZoneId) } truncated * MICROS_PER_MILLIS } @@ -770,32 +768,6 @@ object DateTimeUtils { } } - /** - * Lookup the offset for given millis seconds since 1970-01-01 00:00:00 in given timezone. - * TODO: Improve handling of normalization differences. - * TODO: Replace with JSR-310 or similar system - see SPARK-16788 - */ - private[sql] def getOffsetFromLocalMillis(millisLocal: Long, tz: TimeZone): Long = { - var guess = tz.getRawOffset - // the actual offset should be calculated based on milliseconds in UTC - val offset = tz.getOffset(millisLocal - guess) - if (offset != guess) { - guess = tz.getOffset(millisLocal - offset) - if (guess != offset) { - // fallback to do the reverse lookup using java.time.LocalDateTime - // this should only happen near the start or end of DST - val localDate = LocalDate.ofEpochDay(MILLISECONDS.toDays(millisLocal)) - val localTime = LocalTime.ofNanoOfDay(MILLISECONDS.toNanos( - Math.floorMod(millisLocal, MILLIS_PER_DAY))) - val localDateTime = LocalDateTime.of(localDate, localTime) - val millisEpoch = localDateTime.atZone(tz.toZoneId).toInstant.toEpochMilli - - guess = (millisLocal - millisEpoch).toInt - } - } - guess - } - /** * Convert the timestamp `ts` from one timezone to another. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 77a2ca7e4a828..536c76f042d23 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal import java.text.{DecimalFormat, DecimalFormatSymbols} +import java.time.ZoneOffset import java.util.{Locale, TimeZone} import org.apache.commons.lang3.time.FastDateFormat @@ -137,7 +138,7 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { val expectedDate = format.parse(customDate).getTime val castedDate = parser.makeConverter("_1", DateType, nullable = true) .apply(customDate) - assert(castedDate == DateTimeUtils.millisToDays(expectedDate, TimeZone.getTimeZone("GMT"))) + assert(castedDate == DateTimeUtils.millisToDays(expectedDate, ZoneOffset.UTC)) val timestamp = "2015-01-01 00:00:00" timestampsOptions = new CSVOptions(Map( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index f04149ab7eb29..39b859af47ca9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -56,9 +56,9 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val ts = new Timestamp(toMillis(time)) test("datetime function current_date") { - val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis(), TimeZoneGMT) + val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis(), ZoneOffset.UTC) val cd = CurrentDate(gmtId).eval(EmptyRow).asInstanceOf[Int] - val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis(), TimeZoneGMT) + val d1 = DateTimeUtils.millisToDays(System.currentTimeMillis(), ZoneOffset.UTC) assert(d0 <= cd && cd <= d1 && d1 - d0 <= 1) val cdjst = CurrentDate(jstId).eval(EmptyRow).asInstanceOf[Int] @@ -499,7 +499,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // Valid range of DateType is [0001-01-01, 9999-12-31] val maxMonthInterval = 10000 * 12 checkEvaluation( - AddMonths(Literal(Date.valueOf("0001-01-01")), Literal(maxMonthInterval)), 2933261) + AddMonths(Literal(LocalDate.parse("0001-01-01")), Literal(maxMonthInterval)), + LocalDate.of(10001, 1, 1).toEpochDay.toInt) checkEvaluation( AddMonths(Literal(Date.valueOf("9999-12-31")), Literal(-1 * maxMonthInterval)), -719529) // Test evaluation results between Interpreted mode and Codegen mode @@ -788,7 +789,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( UnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( UnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), @@ -796,7 +797,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(UnixTimestamp( Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), 
timeZoneId), MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz.toZoneId))) val t1 = UnixTimestamp( CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] val t2 = UnixTimestamp( @@ -814,7 +815,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( UnixTimestamp(Literal(date1), Literal.create(null, StringType), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( UnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) } @@ -852,7 +853,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation( ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( ToUnixTimestamp( Literal(sdf2.format(new Timestamp(-1000000))), @@ -861,7 +862,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ToUnixTimestamp( Literal(sdf3.format(Date.valueOf("2015-07-24"))), Literal(fmt3), timeZoneId), MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( - DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) + DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz.toZoneId))) val t1 = ToUnixTimestamp( CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] val t2 = ToUnixTimestamp( @@ -876,7 +877,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ToUnixTimestamp( Literal(date1), Literal.create(null, StringType), timeZoneId), MILLISECONDS.toSeconds( - DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) + DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz.toZoneId))) checkEvaluation( ToUnixTimestamp( Literal("2015-07-24"), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index cabcd3007d1c0..cd0594c775a47 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -86,9 +86,13 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { } test("SPARK-6785: java date conversion before and after epoch") { + def format(d: Date): String = { + TimestampFormatter("uuuu-MM-dd", defaultTimeZone().toZoneId) + .format(d.getTime * MICROS_PER_MILLIS) + } def checkFromToJavaDate(d1: Date): Unit = { val d2 = toJavaDate(fromJavaDate(d1)) - assert(d2.toString === d1.toString) + assert(format(d2) === format(d1)) } val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) @@ -413,22 +417,22 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { test("monthsBetween") { val date1 = date(1997, 2, 28, 10, 30, 0) var date2 = date(1996, 10, 30) - assert(monthsBetween(date1, date2, true, TimeZoneUTC) === 3.94959677) - assert(monthsBetween(date1, date2, false, TimeZoneUTC) === 3.9495967741935485) + assert(monthsBetween(date1, date2, true, ZoneOffset.UTC) === 3.94959677) + 
assert(monthsBetween(date1, date2, false, ZoneOffset.UTC) === 3.9495967741935485) Seq(true, false).foreach { roundOff => date2 = date(2000, 2, 28) - assert(monthsBetween(date1, date2, roundOff, TimeZoneUTC) === -36) + assert(monthsBetween(date1, date2, roundOff, ZoneOffset.UTC) === -36) date2 = date(2000, 2, 29) - assert(monthsBetween(date1, date2, roundOff, TimeZoneUTC) === -36) + assert(monthsBetween(date1, date2, roundOff, ZoneOffset.UTC) === -36) date2 = date(1996, 3, 31) - assert(monthsBetween(date1, date2, roundOff, TimeZoneUTC) === 11) + assert(monthsBetween(date1, date2, roundOff, ZoneOffset.UTC) === 11) } val date3 = date(2000, 2, 28, 16, tz = TimeZonePST) val date4 = date(1997, 2, 28, 16, tz = TimeZonePST) - assert(monthsBetween(date3, date4, true, TimeZonePST) === 36.0) - assert(monthsBetween(date3, date4, true, TimeZoneGMT) === 35.90322581) - assert(monthsBetween(date3, date4, false, TimeZoneGMT) === 35.903225806451616) + assert(monthsBetween(date3, date4, true, TimeZonePST.toZoneId) === 36.0) + assert(monthsBetween(date3, date4, true, ZoneOffset.UTC) === 35.90322581) + assert(monthsBetween(date3, date4, false, ZoneOffset.UTC) === 35.903225806451616) } test("from UTC timestamp") { @@ -571,15 +575,15 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { test("daysToMillis and millisToDays") { val input = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, 16, tz = TimeZonePST)) - assert(millisToDays(input, TimeZonePST) === 16800) - assert(millisToDays(input, TimeZoneGMT) === 16801) - assert(millisToDays(-1 * MILLIS_PER_DAY + 1, TimeZoneGMT) == -1) + assert(millisToDays(input, TimeZonePST.toZoneId) === 16800) + assert(millisToDays(input, ZoneOffset.UTC) === 16801) + assert(millisToDays(-1 * MILLIS_PER_DAY + 1, ZoneOffset.UTC) == -1) var expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, tz = TimeZonePST)) - assert(daysToMillis(16800, TimeZonePST) === expected) + assert(daysToMillis(16800, TimeZonePST.toZoneId) === expected) expected = TimeUnit.MICROSECONDS.toMillis(date(2015, 12, 31, tz = TimeZoneGMT)) - assert(daysToMillis(16800, TimeZoneGMT) === expected) + assert(daysToMillis(16800, ZoneOffset.UTC) === expected) // There are some days are skipped entirely in some timezone, skip them here. 
val skipped_days = Map[String, Set[Int]]( @@ -594,7 +598,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { val skipped = skipped_days.getOrElse(tz.getID, Set.empty) (-20000 to 20000).foreach { d => if (!skipped.contains(d)) { - assert(millisToDays(daysToMillis(d, tz), tz) === d, + assert(millisToDays(daysToMillis(d, tz.toZoneId), tz.toZoneId) === d, s"Round trip of ${d} did not work in tz ${tz}") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index bbe47a63f4d61..5a2f16d8e1526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} @@ -67,8 +68,12 @@ object HiveResult { case (null, _) => if (nested) "null" else "NULL" case (b, BooleanType) => b.toString case (d: Date, DateType) => dateFormatter.format(DateTimeUtils.fromJavaDate(d)) + case (ld: LocalDate, DateType) => + dateFormatter.format(DateTimeUtils.localDateToDays(ld)) case (t: Timestamp, TimestampType) => timestampFormatter.format(DateTimeUtils.fromJavaTimestamp(t)) + case (i: Instant, TimestampType) => + timestampFormatter.format(DateTimeUtils.instantToMicros(i)) case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString case (n, _: NumericType) => n.toString diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out index fd5dc42632176..ed27317121623 100755 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/date.sql.out @@ -800,7 +800,7 @@ SELECT DATE_TRUNC('MILLENNIUM', TIMESTAMP '1970-03-20 04:30:00.00000') -- !query schema struct -- !query output -1001-01-01 00:07:02 +1001-01-01 00:00:00 -- !query @@ -808,7 +808,7 @@ SELECT DATE_TRUNC('MILLENNIUM', DATE '1970-03-20') -- !query schema struct -- !query output -1001-01-01 00:07:02 +1001-01-01 00:00:00 -- !query @@ -840,7 +840,7 @@ SELECT DATE_TRUNC('CENTURY', DATE '0002-02-04') -- !query schema struct -- !query output -0001-01-01 00:07:02 +0001-01-01 00:00:00 -- !query @@ -848,7 +848,7 @@ SELECT DATE_TRUNC('CENTURY', TO_DATE('0055-08-10 BC', 'yyyy-MM-dd G')) -- !query schema struct -- !query output --0099-01-01 00:07:02 +-0099-01-01 00:00:00 -- !query @@ -864,7 +864,7 @@ SELECT DATE_TRUNC('DECADE', DATE '0004-12-25') -- !query schema struct -- !query output -0000-01-01 00:07:02 +0000-01-01 00:00:00 -- !query @@ -872,7 +872,7 @@ SELECT DATE_TRUNC('DECADE', TO_DATE('0002-12-31 BC', 'yyyy-MM-dd G')) -- !query schema struct -- !query output --0010-01-01 00:07:02 +-0010-01-01 00:00:00 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 2e5a9e0b4d45d..6b9e5bbd3c961 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -337,6 +337,7 @@ class 
SQLQueryTestSuite extends QueryTest with SharedSparkSession { localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true) case _ => } + localSparkSession.conf.set(SQLConf.DATETIME_JAVA8API_ENABLED.key, true) if (configSet.nonEmpty) { // Execute the list of set operation in order to add the desired configs From 5b76367a9d0aaca53ce96ab7e555a596567e8335 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 12 Feb 2020 14:27:18 -0800 Subject: [PATCH 028/185] [SPARK-30797][SQL] Set tradition user/group/other permission to ACL entries when setting up ACLs in truncate table ### What changes were proposed in this pull request? This is a follow-up to the PR #26956. In #26956, the patch proposed to preserve path permission when truncating table. When setting up original ACLs, we need to set user/group/other permission as ACL entries too, otherwise if the path doesn't have default user/group/other ACL entries, ACL API will complain an error `Invalid ACL: the user, group and other entries are required.`. In short this change makes sure: 1. Permissions for user/group/other are always kept into ACLs to work with ACL API. 2. Other custom ACLs are still kept after TRUNCATE TABLE (#26956 did this). ### Why are the changes needed? Without this fix, `TRUNCATE TABLE` will get an error when setting up ACLs if there is no default default user/group/other ACL entries. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Update unit test. Manual test on dev Spark cluster. Set ACLs for a table path without default user/group/other ACL entries: ``` hdfs dfs -setfacl --set 'user:liangchi:rwx,user::rwx,group::r--,other::r--' /user/hive/warehouse/test.db/test_truncate_table hdfs dfs -getfacl /user/hive/warehouse/test.db/test_truncate_table # file: /user/hive/warehouse/test.db/test_truncate_table # owner: liangchi # group: supergroup user::rwx user:liangchi:rwx group::r-- mask::rwx other::r-- ``` Then run `sql("truncate table test.test_truncate_table")`, it works by normally truncating the table and preserve ACLs. Closes #27548 from viirya/fix-truncate-table-permission. 
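For illustration only (not part of this patch): a minimal standalone sketch of the approach the diff below takes, capturing a path's named ACL entries plus its traditional user/group/other permission bits and re-applying them together, since `setAcl` rejects an ACL without the user/group/other entries. The warehouse path is a placeholder.

```
import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, FsAction, FsPermission}

object TruncateAclSketch {
  // Build an ACCESS-scope ACL entry from a traditional user/group/other permission bit.
  private def baseEntry(aclType: AclEntryType, action: FsAction): AclEntry =
    new AclEntry.Builder()
      .setScope(AclEntryScope.ACCESS)
      .setType(aclType)
      .setPermission(action)
      .build()

  def main(args: Array[String]): Unit = {
    // Placeholder path; in TRUNCATE TABLE this is the table/partition location.
    val path = new Path("/user/hive/warehouse/test.db/test_truncate_table")
    val fs = FileSystem.get(new Configuration())

    // Capture the original permission bits and the named (custom) ACL entries.
    val permission: FsPermission = fs.getFileStatus(path).getPermission
    val namedEntries = fs.getAclStatus(path).getEntries.asScala.filter(_.getName != null)

    // ... the directory would be deleted and recreated here ...

    // Re-add the user/group/other bits as ACCESS entries alongside the named entries,
    // otherwise setAcl fails with "Invalid ACL: the user, group and other entries are required."
    val fullAcl = (namedEntries ++ Seq(
      baseEntry(AclEntryType.USER, permission.getUserAction),
      baseEntry(AclEntryType.GROUP, permission.getGroupAction),
      baseEntry(AclEntryType.OTHER, permission.getOtherAction))).asJava

    fs.setAcl(path, fullAcl)
  }
}
```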
Lead-authored-by: Liang-Chi Hsieh Co-authored-by: Liang-Chi Hsieh Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/command/tables.scala | 32 +++++++++++++++++-- .../sql/execution/command/DDLSuite.scala | 21 +++++++++++- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 90dbdf5515d4d..61500b773cd7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql.execution.command import java.net.{URI, URISyntaxException} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileContext, FsConstants, Path} -import org.apache.hadoop.fs.permission.{AclEntry, FsPermission} +import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, FsAction, FsPermission} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier @@ -538,12 +539,27 @@ case class TruncateTableCommand( } } optAcls.foreach { acls => + val aclEntries = acls.asScala.filter(_.getName != null).asJava + + // If the path doesn't have default ACLs, `setAcl` API will throw an error + // as it expects user/group/other permissions must be in ACL entries. + // So we need to add tradition user/group/other permission + // in the form of ACL. + optPermission.map { permission => + aclEntries.add(newAclEntry(AclEntryScope.ACCESS, + AclEntryType.USER, permission.getUserAction())) + aclEntries.add(newAclEntry(AclEntryScope.ACCESS, + AclEntryType.GROUP, permission.getGroupAction())) + aclEntries.add(newAclEntry(AclEntryScope.ACCESS, + AclEntryType.OTHER, permission.getOtherAction())) + } + try { - fs.setAcl(path, acls) + fs.setAcl(path, aclEntries) } catch { case NonFatal(e) => throw new SecurityException( - s"Failed to set original ACL $acls back to " + + s"Failed to set original ACL $aclEntries back to " + s"the created path: $path. Exception: ${e.getMessage}") } } @@ -574,6 +590,16 @@ case class TruncateTableCommand( } Seq.empty[Row] } + + private def newAclEntry( + scope: AclEntryScope, + aclType: AclEntryType, + permission: FsAction): AclEntry = { + new AclEntry.Builder() + .setScope(scope) + .setType(aclType) + .setPermission(permission).build() + } } abstract class DescribeCommandBase extends RunnableCommand { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 31e00781ae6b4..dbf4b09403423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -2042,6 +2042,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { // Set ACL to table path. 
val customAcl = new java.util.ArrayList[AclEntry]() customAcl.add(new AclEntry.Builder() + .setName("test") .setType(AclEntryType.USER) .setScope(AclEntryScope.ACCESS) .setPermission(FsAction.READ).build()) @@ -2061,8 +2062,26 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { if (ignore) { assert(aclEntries.size() == 0) } else { - assert(aclEntries.size() == 1) + assert(aclEntries.size() == 4) assert(aclEntries.get(0) == customAcl.get(0)) + + // Setting ACLs will also set user/group/other permissions + // as ACL entries. + val user = new AclEntry.Builder() + .setType(AclEntryType.USER) + .setScope(AclEntryScope.ACCESS) + .setPermission(FsAction.ALL).build() + val group = new AclEntry.Builder() + .setType(AclEntryType.GROUP) + .setScope(AclEntryScope.ACCESS) + .setPermission(FsAction.ALL).build() + val other = new AclEntry.Builder() + .setType(AclEntryType.OTHER) + .setScope(AclEntryScope.ACCESS) + .setPermission(FsAction.ALL).build() + assert(aclEntries.get(1) == user) + assert(aclEntries.get(2) == group) + assert(aclEntries.get(3) == other) } } } From 496f6ac86001d284cbfb7488a63dd3a168919c0f Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Wed, 12 Feb 2020 16:45:42 -0600 Subject: [PATCH 029/185] [SPARK-29148][CORE] Add stage level scheduling dynamic allocation and scheduler backend changes ### What changes were proposed in this pull request? This is another PR for stage level scheduling. In particular this adds changes to the dynamic allocation manager and the scheduler backend to be able to track what executors are needed per ResourceProfile. Note the api is still private to Spark until the entire feature gets in, so this functionality will be there but only usable by tests for profiles other then the DefaultProfile. The main changes here are simply tracking things on a ResourceProfile basis as well as sending the executor requests to the scheduler backend for all ResourceProfiles. I introduce a ResourceProfileManager in this PR that will track all the actual ResourceProfile objects so that we can keep them all in a single place and just pass around and use in datastructures the resource profile id. The resource profile id can be used with the ResourceProfileManager to get the actual ResourceProfile contents. There are various places in the code that use executor "slots" for things. The ResourceProfile adds functionality to keep that calculation in it. This logic is more complex then it should due to standalone mode and mesos coarse grained not setting the executor cores config. They default to all cores on the worker, so calculating slots is harder there. This PR keeps the functionality to make the cores the limiting resource because the scheduler still uses that for "slots" for a few things. This PR does also add the resource profile id to the Stage and stage info classes to be able to test things easier. That full set of changes will come with the scheduler PR that will be after this one. The PR stops at the scheduler backend pieces for the cluster manager and the real YARN support hasn't been added in this PR, that again will be in a separate PR, so this has a few of the API changes up to the cluster manager and then just uses the default profile requests to continue. The code for the entire feature is here for reference: https://github.com/apache/spark/pull/27053/files although it needs to be upmerged again as well. ### Why are the changes needed? Needed for stage level scheduling feature. ### Does this PR introduce any user-facing change? 
No user facing api changes added here. ### How was this patch tested? Lots of unit tests and manually testing. I tested on yarn, k8s, standalone, local modes. Ran both failure and success cases. Closes #27313 from tgravescs/SPARK-29148. Authored-by: Thomas Graves Signed-off-by: Thomas Graves --- .../spark/ExecutorAllocationClient.scala | 31 +- .../spark/ExecutorAllocationManager.scala | 473 +++++--- .../scala/org/apache/spark/SparkContext.scala | 150 +-- .../apache/spark/internal/config/Tests.scala | 9 + .../resource/ExecutorResourceRequests.scala | 2 +- .../spark/resource/ResourceProfile.scala | 150 ++- .../resource/ResourceProfileBuilder.scala | 2 +- .../resource/ResourceProfileManager.scala | 86 ++ .../apache/spark/resource/ResourceUtils.scala | 109 +- .../apache/spark/scheduler/DAGScheduler.scala | 9 +- .../apache/spark/scheduler/ResultStage.scala | 5 +- .../spark/scheduler/ShuffleMapStage.scala | 5 +- .../org/apache/spark/scheduler/Stage.scala | 9 +- .../apache/spark/scheduler/StageInfo.scala | 9 +- .../spark/scheduler/TaskSchedulerImpl.scala | 4 +- .../CoarseGrainedSchedulerBackend.scala | 150 +-- .../cluster/StandaloneSchedulerBackend.scala | 11 +- .../scheduler/dynalloc/ExecutorMonitor.scala | 11 +- .../org/apache/spark/util/JsonProtocol.scala | 8 +- .../ExecutorAllocationManagerSuite.scala | 1049 +++++++++++------ .../apache/spark/HeartbeatReceiverSuite.scala | 9 +- .../org/apache/spark/LocalSparkContext.scala | 2 +- .../org/apache/spark/SparkContextSuite.scala | 36 +- .../BasicEventFilterBuilderSuite.scala | 4 +- .../ResourceProfileManagerSuite.scala | 103 ++ .../spark/resource/ResourceProfileSuite.scala | 79 +- .../spark/resource/ResourceUtilsSuite.scala | 3 + .../CoarseGrainedSchedulerBackendSuite.scala | 13 +- .../scheduler/EventLoggingListenerSuite.scala | 7 +- .../dynalloc/ExecutorMonitorSuite.scala | 19 +- .../spark/status/AppStatusListenerSuite.scala | 76 +- .../status/ListenerEventsTestHelper.scala | 10 +- .../org/apache/spark/ui/StagePageSuite.scala | 4 +- .../apache/spark/util/JsonProtocolSuite.scala | 15 +- python/pyspark/tests/test_context.py | 5 + python/pyspark/tests/test_taskcontext.py | 6 + .../KubernetesClusterSchedulerBackend.scala | 8 +- ...bernetesClusterSchedulerBackendSuite.scala | 4 + .../MesosCoarseGrainedSchedulerBackend.scala | 18 +- ...osCoarseGrainedSchedulerBackendSuite.scala | 26 +- .../cluster/YarnSchedulerBackend.scala | 25 +- .../cluster/YarnSchedulerBackendSuite.scala | 7 +- .../ui/MetricsAggregationBenchmark.scala | 4 +- .../ui/SQLAppStatusListenerSuite.scala | 4 +- .../scheduler/ExecutorAllocationManager.scala | 7 +- .../ExecutorAllocationManagerSuite.scala | 19 +- 46 files changed, 1935 insertions(+), 860 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala create mode 100644 core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index cb965cb180207..00bd0063c9e3a 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -37,24 +37,29 @@ private[spark] trait ExecutorAllocationClient { /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. - * @param numExecutors The total number of executors we'd like to have. 
The cluster manager - * shouldn't kill any running executor to reach this number, but, - * if all existing executors were to die, this is the number of executors - * we'd want to be allocated. - * @param localityAwareTasks The number of tasks in all active stages that have a locality - * preferences. This includes running, pending, and completed tasks. - * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages - * that would like to like to run on that host. - * This includes running, pending, and completed tasks. + * + * @param resourceProfileIdToNumExecutors The total number of executors we'd like to have per + * ResourceProfile id. The cluster manager shouldn't kill + * any running executor to reach this number, but, if all + * existing executors were to die, this is the number + * of executors we'd want to be allocated. + * @param numLocalityAwareTasksPerResourceProfileId The number of tasks in all active stages that + * have a locality preferences per + * ResourceProfile id. This includes running, + * pending, and completed tasks. + * @param hostToLocalTaskCount A map of ResourceProfile id to a map of hosts to the number of + * tasks from all active stages that would like to like to run on + * that host. This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ private[spark] def requestTotalExecutors( - numExecutors: Int, - localityAwareTasks: Int, - hostToLocalTaskCount: Map[String, Int]): Boolean + resourceProfileIdToNumExecutors: Map[Int, Int], + numLocalityAwareTasksPerResourceProfileId: Map[Int, Int], + hostToLocalTaskCount: Map[Int, Map[String, Int]]): Boolean /** - * Request an additional number of executors from the cluster manager. + * Request an additional number of executors from the cluster manager for the default + * ResourceProfile. * @return whether the request is acknowledged by the cluster manager. */ def requestExecutors(numAdditionalExecutors: Int): Boolean diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 677386cc7a572..5cb3160711a90 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -29,6 +29,8 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Tests.TEST_SCHEDULE_INTERVAL import org.apache.spark.metrics.source.Source +import org.apache.spark.resource.ResourceProfile.UNKNOWN_RESOURCE_PROFILE_ID +import org.apache.spark.resource.ResourceProfileManager import org.apache.spark.scheduler._ import org.apache.spark.scheduler.dynalloc.ExecutorMonitor import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} @@ -36,9 +38,9 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. * - * The ExecutorAllocationManager maintains a moving target number of executors which is periodically - * synced to the cluster manager. The target starts at a configured initial value and changes with - * the number of pending and running tasks. + * The ExecutorAllocationManager maintains a moving target number of executors, for each + * ResourceProfile, which is periodically synced to the cluster manager. 
The target starts + * at a configured initial value and changes with the number of pending and running tasks. * * Decreasing the target number of executors happens when the current target is more than needed to * handle the current load. The target number of executors is always truncated to the number of @@ -57,14 +59,18 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * quickly over time in case the maximum number of executors is very high. Otherwise, it will take * a long time to ramp up under heavy workloads. * - * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not - * been scheduled to run any tasks, then it is removed. Note that an executor caching any data + * The remove policy is simpler and is applied on each ResourceProfile separately. If an executor + * for that ResourceProfile has been idle for K seconds and the number of executors is more + * then what is needed for that ResourceProfile, meaning there are not enough tasks that could use + * the executor, then it is removed. Note that an executor caching any data * blocks will be removed if it has been idle for more than L seconds. * * There is no retry logic in either case because we make the assumption that the cluster manager * will eventually fulfill all requests it receives asynchronously. * - * The relevant Spark properties include the following: + * The relevant Spark properties are below. Each of these properties applies separately to + * every ResourceProfile. So if you set a minimum number of executors, that is a minimum + * for each ResourceProfile. * * spark.dynamicAllocation.enabled - Whether this feature is enabled * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors @@ -95,7 +101,8 @@ private[spark] class ExecutorAllocationManager( listenerBus: LiveListenerBus, conf: SparkConf, cleaner: Option[ContextCleaner] = None, - clock: Clock = new SystemClock()) + clock: Clock = new SystemClock(), + resourceProfileManager: ResourceProfileManager) extends Logging { allocationManager => @@ -117,23 +124,23 @@ private[spark] class ExecutorAllocationManager( // During testing, the methods to actually kill and add executors are mocked out private val testing = conf.get(DYN_ALLOCATION_TESTING) - // TODO: The default value of 1 for spark.executor.cores works right now because dynamic - // allocation is only supported for YARN and the default number of cores per executor in YARN is - // 1, but it might need to be attained differently for different cluster managers - private val tasksPerExecutorForFullParallelism = - conf.get(EXECUTOR_CORES) / conf.get(CPUS_PER_TASK) - private val executorAllocationRatio = conf.get(DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO) + private val defaultProfileId = resourceProfileManager.defaultResourceProfile.id + validateSettings() - // Number of executors to add in the next round - private var numExecutorsToAdd = 1 + // Number of executors to add for each ResourceProfile in the next round + private val numExecutorsToAddPerResourceProfileId = new mutable.HashMap[Int, Int] + numExecutorsToAddPerResourceProfileId(defaultProfileId) = 1 // The desired number of executors at this moment in time. If all our executors were to die, this // is the number of executors we would immediately want from the cluster manager. 
- private var numExecutorsTarget = initialNumExecutors + // Note every profile will be allowed to have initial number, + // we may want to make this configurable per Profile in the future + private val numExecutorsTargetPerResourceProfileId = new mutable.HashMap[Int, Int] + numExecutorsTargetPerResourceProfileId(defaultProfileId) = initialNumExecutors // A timestamp of when an addition should be triggered, or NOT_SET if it is not set // This is set when pending tasks are added but not scheduled yet @@ -165,11 +172,12 @@ private[spark] class ExecutorAllocationManager( // (2) an executor idle timeout has elapsed. @volatile private var initializing: Boolean = true - // Number of locality aware tasks, used for executor placement. - private var localityAwareTasks = 0 + // Number of locality aware tasks for each ResourceProfile, used for executor placement. + private var numLocalityAwareTasksPerResourceProfileId = new mutable.HashMap[Int, Int] + numLocalityAwareTasksPerResourceProfileId(defaultProfileId) = 0 - // Host to possible task running on it, used for executor placement. - private var hostToLocalTaskCount: Map[String, Int] = Map.empty + // ResourceProfile id to Host to possible task running on it, used for executor placement. + private var rpIdToHostToLocalTaskCount: Map[Int, Map[String, Int]] = Map.empty /** * Verify that the settings specified through the config are valid. @@ -233,7 +241,14 @@ private[spark] class ExecutorAllocationManager( } executor.scheduleWithFixedDelay(scheduleTask, 0, intervalMillis, TimeUnit.MILLISECONDS) - client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + // copy the maps inside synchonize to ensure not being modified + val (numExecutorsTarget, numLocalityAware) = synchronized { + val numTarget = numExecutorsTargetPerResourceProfileId.toMap + val numLocality = numLocalityAwareTasksPerResourceProfileId.toMap + (numTarget, numLocality) + } + + client.requestTotalExecutors(numExecutorsTarget, numLocalityAware, rpIdToHostToLocalTaskCount) } /** @@ -253,20 +268,28 @@ private[spark] class ExecutorAllocationManager( */ def reset(): Unit = synchronized { addTime = 0L - numExecutorsTarget = initialNumExecutors + numExecutorsTargetPerResourceProfileId.keys.foreach { rpId => + numExecutorsTargetPerResourceProfileId(rpId) = initialNumExecutors + } executorMonitor.reset() } /** - * The maximum number of executors we would need under the current load to satisfy all running - * and pending tasks, rounded up. + * The maximum number of executors, for the ResourceProfile id passed in, that we would need + * under the current load to satisfy all running and pending tasks, rounded up. 
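// Illustrative sketch, not part of the patch: how the per-ResourceProfile bookkeeping introduced
// above behaves. The names mirror the fields in ExecutorAllocationManager, but the profile id and
// initial value used here are made-up assumptions.
object PerProfileTargetsSketch {
  import scala.collection.mutable

  val defaultProfileId = 0           // assumed id of the default ResourceProfile
  val initialNumExecutors = 2        // assumed initial target from the dynamic allocation conf

  // one moving target and one "number to add next round" per ResourceProfile id
  val numExecutorsTargetPerResourceProfileId = mutable.HashMap(defaultProfileId -> initialNumExecutors)
  val numExecutorsToAddPerResourceProfileId = mutable.HashMap(defaultProfileId -> 1)

  // a profile id seen for the first time is seeded lazily, the way onStageSubmitted does it
  def ensureProfileTracked(rpId: Int): Unit = {
    numExecutorsTargetPerResourceProfileId.getOrElseUpdate(rpId, initialNumExecutors)
    numExecutorsToAddPerResourceProfileId.getOrElseUpdate(rpId, 1)
  }
}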
*/ - private def maxNumExecutorsNeeded(): Int = { - val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks + private def maxNumExecutorsNeededPerResourceProfile(rpId: Int): Int = { + val pending = listener.totalPendingTasksPerResourceProfile(rpId) + val pendingSpeculative = listener.pendingSpeculativeTasksPerResourceProfile(rpId) + val running = listener.totalRunningTasksPerResourceProfile(rpId) + val numRunningOrPendingTasks = pending + running + val rp = resourceProfileManager.resourceProfileFromId(rpId) + val tasksPerExecutor = rp.maxTasksPerExecutor(conf) + logDebug(s"max needed for rpId: $rpId numpending: $numRunningOrPendingTasks," + + s" tasksperexecutor: $tasksPerExecutor") val maxNeeded = math.ceil(numRunningOrPendingTasks * executorAllocationRatio / - tasksPerExecutorForFullParallelism).toInt - if (tasksPerExecutorForFullParallelism > 1 && maxNeeded == 1 && - listener.pendingSpeculativeTasks > 0) { + tasksPerExecutor).toInt + if (tasksPerExecutor > 1 && maxNeeded == 1 && pendingSpeculative > 0) { // If we have pending speculative tasks and only need a single executor, allocate one more // to satisfy the locality requirements of speculation maxNeeded + 1 @@ -275,8 +298,8 @@ private[spark] class ExecutorAllocationManager( } } - private def totalRunningTasks(): Int = synchronized { - listener.totalRunningTasks + private def totalRunningTasksPerResourceProfile(id: Int): Int = synchronized { + listener.totalRunningTasksPerResourceProfile(id) } /** @@ -302,7 +325,8 @@ private[spark] class ExecutorAllocationManager( } /** - * Updates our target number of executors and syncs the result with the cluster manager. + * Updates our target number of executors for each ResourceProfile and then syncs the result + * with the cluster manager. * * Check to see whether our existing allocation and the requests we've made previously exceed our * current needs. If so, truncate our target and let the cluster manager know so that it can @@ -314,130 +338,205 @@ private[spark] class ExecutorAllocationManager( * @return the delta in the target number of executors. */ private def updateAndSyncNumExecutorsTarget(now: Long): Int = synchronized { - val maxNeeded = maxNumExecutorsNeeded - if (initializing) { // Do not change our target while we are still initializing, // Otherwise the first job may have to ramp up unnecessarily 0 - } else if (maxNeeded < numExecutorsTarget) { - // The target number exceeds the number we actually need, so stop adding new - // executors and inform the cluster manager to cancel the extra pending requests - val oldNumExecutorsTarget = numExecutorsTarget - numExecutorsTarget = math.max(maxNeeded, minNumExecutors) - numExecutorsToAdd = 1 - - // If the new target has not changed, avoid sending a message to the cluster manager - if (numExecutorsTarget < oldNumExecutorsTarget) { - // We lower the target number of executors but don't actively kill any yet. Killing is - // controlled separately by an idle timeout. It's still helpful to reduce the target number - // in case an executor just happens to get lost (eg., bad hardware, or the cluster manager - // preempts it) -- in that case, there is no point in trying to immediately get a new - // executor, since we wouldn't even use it yet. 
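// Illustrative sketch, not part of the patch: the arithmetic behind
// maxNumExecutorsNeededPerResourceProfile above, written as a pure function. The allocation ratio
// and tasks-per-executor arguments stand in for the conf and ResourceProfile lookups.
object MaxExecutorsNeededSketch {
  def maxNeeded(
      pending: Int,
      pendingSpeculative: Int,
      running: Int,
      tasksPerExecutor: Int,
      executorAllocationRatio: Double = 1.0): Int = {
    val base = math.ceil((pending + running) * executorAllocationRatio / tasksPerExecutor).toInt
    // keep one extra executor for speculation locality when only a single executor would be needed
    if (tasksPerExecutor > 1 && base == 1 && pendingSpeculative > 0) base + 1 else base
  }
  // e.g. 10 pending + 2 running tasks with 4 slots per executor => ceil(12 / 4) = 3 executors
}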
- client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) - logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " + - s"$oldNumExecutorsTarget) because not all requested executors are actually needed") + } else { + val updatesNeeded = new mutable.HashMap[Int, ExecutorAllocationManager.TargetNumUpdates] + + // Update targets for all ResourceProfiles then do a single request to the cluster manager + numExecutorsTargetPerResourceProfileId.foreach { case (rpId, targetExecs) => + val maxNeeded = maxNumExecutorsNeededPerResourceProfile(rpId) + if (maxNeeded < targetExecs) { + // The target number exceeds the number we actually need, so stop adding new + // executors and inform the cluster manager to cancel the extra pending requests + + // We lower the target number of executors but don't actively kill any yet. Killing is + // controlled separately by an idle timeout. It's still helpful to reduce + // the target number in case an executor just happens to get lost (eg., bad hardware, + // or the cluster manager preempts it) -- in that case, there is no point in trying + // to immediately get a new executor, since we wouldn't even use it yet. + decrementExecutorsFromTarget(maxNeeded, rpId, updatesNeeded) + } else if (addTime != NOT_SET && now >= addTime) { + addExecutorsToTarget(maxNeeded, rpId, updatesNeeded) + } + } + doUpdateRequest(updatesNeeded.toMap, now) + } + } + + private def addExecutorsToTarget( + maxNeeded: Int, + rpId: Int, + updatesNeeded: mutable.HashMap[Int, ExecutorAllocationManager.TargetNumUpdates]): Int = { + updateTargetExecs(addExecutors, maxNeeded, rpId, updatesNeeded) + } + + private def decrementExecutorsFromTarget( + maxNeeded: Int, + rpId: Int, + updatesNeeded: mutable.HashMap[Int, ExecutorAllocationManager.TargetNumUpdates]): Int = { + updateTargetExecs(decrementExecutors, maxNeeded, rpId, updatesNeeded) + } + + private def updateTargetExecs( + updateTargetFn: (Int, Int) => Int, + maxNeeded: Int, + rpId: Int, + updatesNeeded: mutable.HashMap[Int, ExecutorAllocationManager.TargetNumUpdates]): Int = { + val oldNumExecutorsTarget = numExecutorsTargetPerResourceProfileId(rpId) + // update the target number (add or remove) + val delta = updateTargetFn(maxNeeded, rpId) + if (delta != 0) { + updatesNeeded(rpId) = ExecutorAllocationManager.TargetNumUpdates(delta, oldNumExecutorsTarget) + } + delta + } + + private def doUpdateRequest( + updates: Map[Int, ExecutorAllocationManager.TargetNumUpdates], + now: Long): Int = { + // Only call cluster manager if target has changed. + if (updates.size > 0) { + val requestAcknowledged = try { + logDebug("requesting updates: " + updates) + testing || + client.requestTotalExecutors( + numExecutorsTargetPerResourceProfileId.toMap, + numLocalityAwareTasksPerResourceProfileId.toMap, + rpIdToHostToLocalTaskCount) + } catch { + case NonFatal(e) => + // Use INFO level so the error it doesn't show up by default in shells. + // Errors here are more commonly caused by YARN AM restarts, which is a recoverable + // issue, and generate a lot of noisy output. 
+ logInfo("Error reaching cluster manager.", e) + false + } + if (requestAcknowledged) { + // have to go through all resource profiles that changed + var totalDelta = 0 + updates.foreach { case (rpId, targetNum) => + val delta = targetNum.delta + totalDelta += delta + if (delta > 0) { + val executorsString = "executor" + { if (delta > 1) "s" else "" } + logInfo(s"Requesting $delta new $executorsString because tasks are backlogged " + + s"(new desired total will be ${numExecutorsTargetPerResourceProfileId(rpId)} " + + s"for resource profile id: ${rpId})") + numExecutorsToAddPerResourceProfileId(rpId) = + if (delta == numExecutorsToAddPerResourceProfileId(rpId)) { + numExecutorsToAddPerResourceProfileId(rpId) * 2 + } else { + 1 + } + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") + addTime = now + TimeUnit.SECONDS.toNanos(sustainedSchedulerBacklogTimeoutS) + } else { + logDebug(s"Lowering target number of executors to" + + s" ${numExecutorsTargetPerResourceProfileId(rpId)} (previously " + + s"$targetNum.oldNumExecutorsTarget for resource profile id: ${rpId}) " + + "because not all requested executors " + + "are actually needed") + } + } + totalDelta + } else { + // request was for all profiles so we have to go through all to reset to old num + updates.foreach { case (rpId, targetNum) => + logWarning("Unable to reach the cluster manager to request more executors!") + numExecutorsTargetPerResourceProfileId(rpId) = targetNum.oldNumExecutorsTarget + } + 0 } - numExecutorsTarget - oldNumExecutorsTarget - } else if (addTime != NOT_SET && now >= addTime) { - val delta = addExecutors(maxNeeded) - logDebug(s"Starting timer to add more executors (to " + - s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime = now + TimeUnit.SECONDS.toNanos(sustainedSchedulerBacklogTimeoutS) - delta } else { + logDebug("No change in number of executors") 0 } } + private def decrementExecutors(maxNeeded: Int, rpId: Int): Int = { + val oldNumExecutorsTarget = numExecutorsTargetPerResourceProfileId(rpId) + numExecutorsTargetPerResourceProfileId(rpId) = math.max(maxNeeded, minNumExecutors) + numExecutorsToAddPerResourceProfileId(rpId) = 1 + numExecutorsTargetPerResourceProfileId(rpId) - oldNumExecutorsTarget + } + /** - * Request a number of executors from the cluster manager. + * Update the target number of executors and figure out how many to add. * If the cap on the number of executors is reached, give up and reset the * number of executors to add next round instead of continuing to double it. * * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending * tasks could fill + * @param rpId the ResourceProfile id of the executors * @return the number of additional executors actually requested. 
*/ - private def addExecutors(maxNumExecutorsNeeded: Int): Int = { + private def addExecutors(maxNumExecutorsNeeded: Int, rpId: Int): Int = { + val oldNumExecutorsTarget = numExecutorsTargetPerResourceProfileId(rpId) // Do not request more executors if it would put our target over the upper bound - if (numExecutorsTarget >= maxNumExecutors) { - logDebug(s"Not adding executors because our current target total " + - s"is already $numExecutorsTarget (limit $maxNumExecutors)") - numExecutorsToAdd = 1 + // this is doing a max check per ResourceProfile + if (oldNumExecutorsTarget >= maxNumExecutors) { + logDebug("Not adding executors because our current target total " + + s"is already ${oldNumExecutorsTarget} (limit $maxNumExecutors)") + numExecutorsToAddPerResourceProfileId(rpId) = 1 return 0 } - - val oldNumExecutorsTarget = numExecutorsTarget // There's no point in wasting time ramping up to the number of executors we already have, so // make sure our target is at least as much as our current allocation: - numExecutorsTarget = math.max(numExecutorsTarget, executorMonitor.executorCount) + var numExecutorsTarget = math.max(numExecutorsTargetPerResourceProfileId(rpId), + executorMonitor.executorCountWithResourceProfile(rpId)) // Boost our target with the number to add for this round: - numExecutorsTarget += numExecutorsToAdd + numExecutorsTarget += numExecutorsToAddPerResourceProfileId(rpId) // Ensure that our target doesn't exceed what we need at the present moment: numExecutorsTarget = math.min(numExecutorsTarget, maxNumExecutorsNeeded) // Ensure that our target fits within configured bounds: numExecutorsTarget = math.max(math.min(numExecutorsTarget, maxNumExecutors), minNumExecutors) - val delta = numExecutorsTarget - oldNumExecutorsTarget + numExecutorsTargetPerResourceProfileId(rpId) = numExecutorsTarget // If our target has not changed, do not send a message // to the cluster manager and reset our exponential growth if (delta == 0) { - numExecutorsToAdd = 1 - return 0 - } - - val addRequestAcknowledged = try { - testing || - client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) - } catch { - case NonFatal(e) => - // Use INFO level so the error it doesn't show up by default in shells. Errors here are more - // commonly caused by YARN AM restarts, which is a recoverable issue, and generate a lot of - // noisy output. - logInfo("Error reaching cluster manager.", e) - false - } - if (addRequestAcknowledged) { - val executorsString = "executor" + { if (delta > 1) "s" else "" } - logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" + - s" (new desired total will be $numExecutorsTarget)") - numExecutorsToAdd = if (delta == numExecutorsToAdd) { - numExecutorsToAdd * 2 - } else { - 1 - } - delta - } else { - logWarning( - s"Unable to reach the cluster manager to request $numExecutorsTarget total executors!") - numExecutorsTarget = oldNumExecutorsTarget - 0 + numExecutorsToAddPerResourceProfileId(rpId) = 1 } + delta } /** * Request the cluster manager to remove the given executors. * Returns the list of executors which are removed. 
*/ - private def removeExecutors(executors: Seq[String]): Seq[String] = synchronized { + private def removeExecutors(executors: Seq[(String, Int)]): Seq[String] = synchronized { val executorIdsToBeRemoved = new ArrayBuffer[String] - logDebug(s"Request to remove executorIds: ${executors.mkString(", ")}") - val numExistingExecutors = executorMonitor.executorCount - executorMonitor.pendingRemovalCount - - var newExecutorTotal = numExistingExecutors - executors.foreach { executorIdToBeRemoved => - if (newExecutorTotal - 1 < minNumExecutors) { - logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + - s"$newExecutorTotal executor(s) left (minimum number of executor limit $minNumExecutors)") - } else if (newExecutorTotal - 1 < numExecutorsTarget) { - logDebug(s"Not removing idle executor $executorIdToBeRemoved because there are only " + - s"$newExecutorTotal executor(s) left (number of executor target $numExecutorsTarget)") + val numExecutorsTotalPerRpId = mutable.Map[Int, Int]() + executors.foreach { case (executorIdToBeRemoved, rpId) => + if (rpId == UNKNOWN_RESOURCE_PROFILE_ID) { + if (testing) { + throw new SparkException("ResourceProfile Id was UNKNOWN, this is not expected") + } + logWarning(s"Not removing executor $executorIdsToBeRemoved because the " + + "ResourceProfile was UNKNOWN!") } else { - executorIdsToBeRemoved += executorIdToBeRemoved - newExecutorTotal -= 1 + // get the running total as we remove or initialize it to the count - pendingRemoval + val newExecutorTotal = numExecutorsTotalPerRpId.getOrElseUpdate(rpId, + (executorMonitor.executorCountWithResourceProfile(rpId) - + executorMonitor.pendingRemovalCountPerResourceProfileId(rpId))) + if (newExecutorTotal - 1 < minNumExecutors) { + logDebug(s"Not removing idle executor $executorIdToBeRemoved because there " + + s"are only $newExecutorTotal executor(s) left (minimum number of executor limit " + + s"$minNumExecutors)") + } else if (newExecutorTotal - 1 < numExecutorsTargetPerResourceProfileId(rpId)) { + logDebug(s"Not removing idle executor $executorIdToBeRemoved because there " + + s"are only $newExecutorTotal executor(s) left (number of executor " + + s"target ${numExecutorsTargetPerResourceProfileId(rpId)})") + } else { + executorIdsToBeRemoved += executorIdToBeRemoved + numExecutorsTotalPerRpId(rpId) -= 1 + } } } @@ -457,14 +556,15 @@ private[spark] class ExecutorAllocationManager( // [SPARK-21834] killExecutors api reduces the target number of executors. // So we need to update the target with desired value. - client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount) + client.requestTotalExecutors( + numExecutorsTargetPerResourceProfileId.toMap, + numLocalityAwareTasksPerResourceProfileId.toMap, + rpIdToHostToLocalTaskCount) + // reset the newExecutorTotal to the existing number of executors - newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { - newExecutorTotal -= executorsRemoved.size executorMonitor.executorsKilled(executorsRemoved) - logInfo(s"Executors ${executorsRemoved.mkString(",")} removed due to idle timeout." 
+ - s"(new desired total will be $newExecutorTotal)") + logInfo(s"Executors ${executorsRemoved.mkString(",")} removed due to idle timeout.") executorsRemoved } else { logWarning(s"Unable to reach the cluster manager to kill executor/s " + @@ -493,7 +593,7 @@ private[spark] class ExecutorAllocationManager( private def onSchedulerQueueEmpty(): Unit = synchronized { logDebug("Clearing timer to add executors because there are no more pending tasks") addTime = NOT_SET - numExecutorsToAdd = 1 + numExecutorsToAddPerResourceProfileId.transform { case (_, _) => 1 } } private case class StageAttempt(stageId: Int, stageAttemptId: Int) { @@ -519,12 +619,16 @@ private[spark] class ExecutorAllocationManager( private val stageAttemptToSpeculativeTaskIndices = new mutable.HashMap[StageAttempt, mutable.HashSet[Int]] + private val resourceProfileIdToStageAttempt = + new mutable.HashMap[Int, mutable.Set[StageAttempt]] + // stageAttempt to tuple (the number of task with locality preferences, a map where each pair - // is a node and the number of tasks that would like to be scheduled on that node) map, + // is a node and the number of tasks that would like to be scheduled on that node, and + // the resource profile id) map, // maintain the executor placement hints for each stageAttempt used by resource framework // to better place the executors. private val stageAttemptToExecutorPlacementHints = - new mutable.HashMap[StageAttempt, (Int, Map[String, Int])] + new mutable.HashMap[StageAttempt, (Int, Map[String, Int], Int)] override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { initializing = false @@ -535,6 +639,13 @@ private[spark] class ExecutorAllocationManager( allocationManager.synchronized { stageAttemptToNumTasks(stageAttempt) = numTasks allocationManager.onSchedulerBacklogged() + // need to keep stage task requirements to ask for the right containers + val profId = stageSubmitted.stageInfo.resourceProfileId + logDebug(s"Stage resource profile id is: $profId with numTasks: $numTasks") + resourceProfileIdToStageAttempt.getOrElseUpdate( + profId, new mutable.HashSet[StageAttempt]) += stageAttempt + numExecutorsToAddPerResourceProfileId.getOrElseUpdate(profId, 1) + numExecutorsTargetPerResourceProfileId.getOrElseUpdate(profId, initialNumExecutors) // Compute the number of tasks requested by the stage on each host var numTasksPending = 0 @@ -549,7 +660,7 @@ private[spark] class ExecutorAllocationManager( } } stageAttemptToExecutorPlacementHints.put(stageAttempt, - (numTasksPending, hostToLocalTaskCountPerStage.toMap)) + (numTasksPending, hostToLocalTaskCountPerStage.toMap, profId)) // Update the executor placement hints updateExecutorPlacementHints() @@ -561,7 +672,7 @@ private[spark] class ExecutorAllocationManager( val stageAttemptId = stageCompleted.stageInfo.attemptNumber() val stageAttempt = StageAttempt(stageId, stageAttemptId) allocationManager.synchronized { - // do NOT remove stageAttempt from stageAttemptToNumRunningTasks, + // do NOT remove stageAttempt from stageAttemptToNumRunningTask // because the attempt may still have running tasks, // even after another attempt for the stage is submitted. 
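// Illustrative sketch, not part of the patch: the reverse index kept above so task counts can be
// aggregated per ResourceProfile id. StageAttempt mirrors the listener's private case class; the
// ids used are hypothetical.
object ProfileToStageIndexSketch {
  import scala.collection.mutable

  case class StageAttempt(stageId: Int, stageAttemptId: Int)

  val resourceProfileIdToStageAttempt = new mutable.HashMap[Int, mutable.Set[StageAttempt]]

  // called when a stage is submitted with a given resource profile id
  def trackStage(rpId: Int, attempt: StageAttempt): Unit = {
    resourceProfileIdToStageAttempt.getOrElseUpdate(rpId, new mutable.HashSet[StageAttempt]) += attempt
  }

  // any per-profile aggregation only has to walk that profile's stage attempts
  def attemptsFor(rpId: Int): Seq[StageAttempt] =
    resourceProfileIdToStageAttempt.getOrElse(rpId, Set.empty[StageAttempt]).toSeq
}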
stageAttemptToNumTasks -= stageAttempt @@ -597,7 +708,7 @@ private[spark] class ExecutorAllocationManager( stageAttemptToTaskIndices.getOrElseUpdate(stageAttempt, new mutable.HashSet[Int]) += taskIndex } - if (totalPendingTasks() == 0) { + if (!hasPendingTasks) { allocationManager.onSchedulerQueueEmpty() } } @@ -613,9 +724,22 @@ private[spark] class ExecutorAllocationManager( stageAttemptToNumRunningTask(stageAttempt) -= 1 if (stageAttemptToNumRunningTask(stageAttempt) == 0) { stageAttemptToNumRunningTask -= stageAttempt + if (!stageAttemptToNumTasks.contains(stageAttempt)) { + val rpForStage = resourceProfileIdToStageAttempt.filter { case (k, v) => + v.contains(stageAttempt) + }.keys + if (rpForStage.size == 1) { + // be careful about the removal from here due to late tasks, make sure stage is + // really complete and no tasks left + resourceProfileIdToStageAttempt(rpForStage.head) -= stageAttempt + } else { + logWarning(s"Should have exactly one resource profile for stage $stageAttempt," + + s" but have $rpForStage") + } + } + } } - if (taskEnd.taskInfo.speculative) { stageAttemptToSpeculativeTaskIndices.get(stageAttempt).foreach {_.remove{taskIndex}} stageAttemptToNumSpeculativeTasks(stageAttempt) -= 1 @@ -624,7 +748,7 @@ private[spark] class ExecutorAllocationManager( taskEnd.reason match { case Success | _: TaskKilled => case _ => - if (totalPendingTasks() == 0) { + if (!hasPendingTasks) { // If the task failed (not intentionally killed), we expect it to be resubmitted // later. To ensure we have enough resources to run the resubmitted task, we need to // mark the scheduler as backlogged again if it's not already marked as such @@ -661,20 +785,46 @@ private[spark] class ExecutorAllocationManager( * * Note: This is not thread-safe without the caller owning the `allocationManager` lock. 
*/ - def pendingTasks(): Int = { - stageAttemptToNumTasks.map { case (stageAttempt, numTasks) => - numTasks - stageAttemptToTaskIndices.get(stageAttempt).map(_.size).getOrElse(0) - }.sum + def pendingTasksPerResourceProfile(rpId: Int): Int = { + val attempts = resourceProfileIdToStageAttempt.getOrElse(rpId, Set.empty).toSeq + attempts.map(attempt => getPendingTaskSum(attempt)).sum } - def pendingSpeculativeTasks(): Int = { - stageAttemptToNumSpeculativeTasks.map { case (stageAttempt, numTasks) => - numTasks - stageAttemptToSpeculativeTaskIndices.get(stageAttempt).map(_.size).getOrElse(0) - }.sum + def hasPendingRegularTasks: Boolean = { + val attemptSets = resourceProfileIdToStageAttempt.values + attemptSets.exists(attempts => attempts.exists(getPendingTaskSum(_) > 0)) + } + + private def getPendingTaskSum(attempt: StageAttempt): Int = { + val numTotalTasks = stageAttemptToNumTasks.getOrElse(attempt, 0) + val numRunning = stageAttemptToTaskIndices.get(attempt).map(_.size).getOrElse(0) + numTotalTasks - numRunning } - def totalPendingTasks(): Int = { - pendingTasks + pendingSpeculativeTasks + def pendingSpeculativeTasksPerResourceProfile(rp: Int): Int = { + val attempts = resourceProfileIdToStageAttempt.getOrElse(rp, Set.empty).toSeq + attempts.map(attempt => getPendingSpeculativeTaskSum(attempt)).sum + } + + def hasPendingSpeculativeTasks: Boolean = { + val attemptSets = resourceProfileIdToStageAttempt.values + attemptSets.exists { attempts => + attempts.exists(getPendingSpeculativeTaskSum(_) > 0) + } + } + + private def getPendingSpeculativeTaskSum(attempt: StageAttempt): Int = { + val numTotalTasks = stageAttemptToNumSpeculativeTasks.getOrElse(attempt, 0) + val numRunning = stageAttemptToSpeculativeTaskIndices.get(attempt).map(_.size).getOrElse(0) + numTotalTasks - numRunning + } + + def hasPendingTasks: Boolean = { + hasPendingSpeculativeTasks || hasPendingRegularTasks + } + + def totalPendingTasksPerResourceProfile(rp: Int): Int = { + pendingTasksPerResourceProfile(rp) + pendingSpeculativeTasksPerResourceProfile(rp) } /** @@ -685,6 +835,14 @@ private[spark] class ExecutorAllocationManager( stageAttemptToNumRunningTask.values.sum } + def totalRunningTasksPerResourceProfile(rp: Int): Int = { + val attempts = resourceProfileIdToStageAttempt.getOrElse(rp, Set.empty).toSeq + // attempts is a Set, change to Seq so we keep all values + attempts.map { attempt => + stageAttemptToNumRunningTask.getOrElseUpdate(attempt, 0) + }.sum + } + /** * Update the Executor placement hints (the number of tasks with locality preferences, * a map where each pair is a node and the number of tasks that would like to be scheduled @@ -694,18 +852,27 @@ private[spark] class ExecutorAllocationManager( * granularity within stages. 
*/ def updateExecutorPlacementHints(): Unit = { - var localityAwareTasks = 0 - val localityToCount = new mutable.HashMap[String, Int]() - stageAttemptToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) => - localityAwareTasks += numTasksPending - localities.foreach { case (hostname, count) => - val updatedCount = localityToCount.getOrElse(hostname, 0) + count - localityToCount(hostname) = updatedCount - } + val localityAwareTasksPerResourceProfileId = new mutable.HashMap[Int, Int] + + // ResourceProfile id => map[host, count] + val rplocalityToCount = new mutable.HashMap[Int, mutable.HashMap[String, Int]]() + stageAttemptToExecutorPlacementHints.values.foreach { + case (numTasksPending, localities, rpId) => + val rpNumPending = + localityAwareTasksPerResourceProfileId.getOrElse(rpId, 0) + localityAwareTasksPerResourceProfileId(rpId) = rpNumPending + numTasksPending + localities.foreach { case (hostname, count) => + val rpBasedHostToCount = + rplocalityToCount.getOrElseUpdate(rpId, new mutable.HashMap[String, Int]) + val newUpdated = rpBasedHostToCount.getOrElse(hostname, 0) + count + rpBasedHostToCount(hostname) = newUpdated + } } - allocationManager.localityAwareTasks = localityAwareTasks - allocationManager.hostToLocalTaskCount = localityToCount.toMap + allocationManager.numLocalityAwareTasksPerResourceProfileId = + localityAwareTasksPerResourceProfileId + allocationManager.rpIdToHostToLocalTaskCount = + rplocalityToCount.map { case (k, v) => (k, v.toMap)}.toMap } } @@ -726,14 +893,22 @@ private[spark] class ExecutorAllocationManager( }) } - registerGauge("numberExecutorsToAdd", numExecutorsToAdd, 0) + // The metrics are going to return the sum for all the different ResourceProfiles. + registerGauge("numberExecutorsToAdd", + numExecutorsToAddPerResourceProfileId.values.sum, 0) registerGauge("numberExecutorsPendingToRemove", executorMonitor.pendingRemovalCount, 0) registerGauge("numberAllExecutors", executorMonitor.executorCount, 0) - registerGauge("numberTargetExecutors", numExecutorsTarget, 0) - registerGauge("numberMaxNeededExecutors", maxNumExecutorsNeeded(), 0) + registerGauge("numberTargetExecutors", + numExecutorsTargetPerResourceProfileId.values.sum, 0) + registerGauge("numberMaxNeededExecutors", numExecutorsTargetPerResourceProfileId.keys + .map(maxNumExecutorsNeededPerResourceProfile(_)).sum, 0) } } private object ExecutorAllocationManager { val NOT_SET = Long.MaxValue + + // helper case class for requesting executors, here to be visible for testing + private[spark] case class TargetNumUpdates(delta: Int, oldNumExecutorsTarget: Int) + } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 91188d58f4201..a47136ea36736 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -25,6 +25,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReferenc import scala.collection.JavaConverters._ import scala.collection.Map +import scala.collection.immutable import scala.collection.mutable.HashMap import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} @@ -53,7 +54,7 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.source.JVMCPUSource import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ -import org.apache.spark.resource.{ResourceID, ResourceInformation} +import 
org.apache.spark.resource._ import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ @@ -219,9 +220,10 @@ class SparkContext(config: SparkConf) extends Logging { private var _shutdownHookRef: AnyRef = _ private var _statusStore: AppStatusStore = _ private var _heartbeater: Heartbeater = _ - private var _resources: scala.collection.immutable.Map[String, ResourceInformation] = _ + private var _resources: immutable.Map[String, ResourceInformation] = _ private var _shuffleDriverComponents: ShuffleDriverComponents = _ private var _plugins: Option[PluginContainer] = None + private var _resourceProfileManager: ResourceProfileManager = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -343,6 +345,8 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def executorAllocationManager: Option[ExecutorAllocationManager] = _executorAllocationManager + private[spark] def resourceProfileManager: ResourceProfileManager = _resourceProfileManager + private[spark] def cleaner: Option[ContextCleaner] = _cleaner private[spark] var checkpointDir: Option[String] = None @@ -451,6 +455,7 @@ class SparkContext(config: SparkConf) extends Logging { } _listenerBus = new LiveListenerBus(_conf) + _resourceProfileManager = new ResourceProfileManager(_conf) // Initialize the app status store and listener before SparkEnv is created so that it gets // all events. @@ -611,7 +616,7 @@ class SparkContext(config: SparkConf) extends Logging { case b: ExecutorAllocationClient => Some(new ExecutorAllocationManager( schedulerBackend.asInstanceOf[ExecutorAllocationClient], listenerBus, _conf, - cleaner = cleaner)) + cleaner = cleaner, resourceProfileManager = resourceProfileManager)) case _ => None } @@ -1622,7 +1627,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Update the cluster manager on our scheduling needs. Three bits of information are included - * to help it make decisions. + * to help it make decisions. This applies to the default ResourceProfile. * @param numExecutors The total number of executors we'd like to have. The cluster manager * shouldn't kill any running executor to reach this number, but, * if all existing executors were to die, this is the number of executors @@ -1638,11 +1643,16 @@ class SparkContext(config: SparkConf) extends Logging { def requestTotalExecutors( numExecutors: Int, localityAwareTasks: Int, - hostToLocalTaskCount: scala.collection.immutable.Map[String, Int] + hostToLocalTaskCount: immutable.Map[String, Int] ): Boolean = { schedulerBackend match { case b: ExecutorAllocationClient => - b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount) + // this is being applied to the default resource profile, would need to add api to support + // others + val defaultProfId = resourceProfileManager.defaultResourceProfile.id + b.requestTotalExecutors(immutable.Map(defaultProfId -> numExecutors), + immutable.Map(defaultProfId -> localityAwareTasks), + immutable.Map(defaultProfId -> hostToLocalTaskCount)) case _ => logWarning("Requesting executors is not supported by current scheduler.") false @@ -2036,6 +2046,7 @@ class SparkContext(config: SparkConf) extends Logging { // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this // `SparkContext` is stopped.
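// Illustrative sketch, not part of the patch: the shape of the three per-profile maps that the
// default-profile wrapper above builds from the legacy (numExecutors, localityAwareTasks,
// hostToLocalTaskCount) arguments. Profile id 0 is an assumed id for the default profile.
object DefaultProfileRequestSketch {
  def wrap(
      defaultProfId: Int,
      numExecutors: Int,
      localityAwareTasks: Int,
      hostToLocalTaskCount: Map[String, Int])
    : (Map[Int, Int], Map[Int, Int], Map[Int, Map[String, Int]]) = {
    (Map(defaultProfId -> numExecutors),          // executors wanted, keyed by profile id
      Map(defaultProfId -> localityAwareTasks),   // locality-aware task count, keyed by profile id
      Map(defaultProfId -> hostToLocalTaskCount)) // host -> task count hints, keyed by profile id
  }
  // e.g. wrap(0, 10, 4, Map("host1" -> 3)) keys every piece of information by profile id 0
}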
localProperties.remove() + ResourceProfile.clearDefaultProfile() // Unset YARN mode system env variable, to allow switching between cluster types. SparkContext.clearActiveContext() logInfo("Successfully stopped SparkContext") @@ -2771,109 +2782,34 @@ object SparkContext extends Logging { // When running locally, don't try to re-execute tasks on failure. val MAX_LOCAL_TASK_FAILURES = 1 - // Ensure that executor's resources satisfies one or more tasks requirement. - def checkResourcesPerTask(clusterMode: Boolean, executorCores: Option[Int]): Unit = { + // Ensure that default executor's resources satisfies one or more tasks requirement. + // This function is for cluster managers that don't set the executor cores config, for + // others its checked in ResourceProfile. + def checkResourcesPerTask(executorCores: Int): Unit = { val taskCores = sc.conf.get(CPUS_PER_TASK) - val execCores = if (clusterMode) { - executorCores.getOrElse(sc.conf.get(EXECUTOR_CORES)) - } else { - executorCores.get - } - // some cluster managers don't set the EXECUTOR_CORES config by default (standalone - // and mesos coarse grained), so we can't rely on that config for those. - val shouldCheckExecCores = executorCores.isDefined || sc.conf.contains(EXECUTOR_CORES) || - (master.equalsIgnoreCase("yarn") || master.startsWith("k8s")) - - // Number of cores per executor must meet at least one task requirement. - if (shouldCheckExecCores && execCores < taskCores) { - throw new SparkException(s"The number of cores per executor (=$execCores) has to be >= " + - s"the task config: ${CPUS_PER_TASK.key} = $taskCores when run on $master.") - } - - // Calculate the max slots each executor can provide based on resources available on each - // executor and resources required by each task. - val taskResourceRequirements = parseResourceRequirements(sc.conf, SPARK_TASK_PREFIX) - val executorResourcesAndAmounts = parseAllResourceRequests(sc.conf, SPARK_EXECUTOR_PREFIX) - .map(request => (request.id.resourceName, request.amount)).toMap - - var (numSlots, limitingResourceName) = if (shouldCheckExecCores) { - (execCores / taskCores, "CPU") - } else { - (-1, "") - } - - taskResourceRequirements.foreach { taskReq => - // Make sure the executor resources were specified through config. - val execAmount = executorResourcesAndAmounts.getOrElse(taskReq.resourceName, - throw new SparkException("The executor resource config: " + - new ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf + - " needs to be specified since a task requirement config: " + - new ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf + - " was specified") - ) - // Make sure the executor resources are large enough to launch at least one task. - if (execAmount < taskReq.amount) { - throw new SparkException("The executor resource config: " + - new ResourceID(SPARK_EXECUTOR_PREFIX, taskReq.resourceName).amountConf + - s" = $execAmount has to be >= the requested amount in task resource config: " + - new ResourceID(SPARK_TASK_PREFIX, taskReq.resourceName).amountConf + - s" = ${taskReq.amount}") - } - // Compare and update the max slots each executor can provide. - // If the configured amount per task was < 1.0, a task is subdividing - // executor resources. If the amount per task was > 1.0, the task wants - // multiple executor resources. 
- val resourceNumSlots = Math.floor(execAmount * taskReq.numParts / taskReq.amount).toInt - if (resourceNumSlots < numSlots) { - if (shouldCheckExecCores) { - throw new IllegalArgumentException("The number of slots on an executor has to be " + - "limited by the number of cores, otherwise you waste resources and " + - "dynamic allocation doesn't work properly. Your configuration has " + - s"core/task cpu slots = ${numSlots} and " + - s"${taskReq.resourceName} = ${resourceNumSlots}. " + - "Please adjust your configuration so that all resources require same number " + - "of executor slots.") - } - numSlots = resourceNumSlots - limitingResourceName = taskReq.resourceName - } - } - if(!shouldCheckExecCores && Utils.isDynamicAllocationEnabled(sc.conf)) { - // if we can't rely on the executor cores config throw a warning for user - logWarning("Please ensure that the number of slots available on your " + - "executors is limited by the number of cores to task cpus and not another " + - "custom resource. If cores is not the limiting resource then dynamic " + - "allocation will not work properly!") - } - // warn if we would waste any resources due to another resource limiting the number of - // slots on an executor - taskResourceRequirements.foreach { taskReq => - val execAmount = executorResourcesAndAmounts(taskReq.resourceName) - if ((numSlots * taskReq.amount / taskReq.numParts) < execAmount) { - val taskReqStr = if (taskReq.numParts > 1) { - s"${taskReq.amount}/${taskReq.numParts}" - } else { - s"${taskReq.amount}" - } - val resourceNumSlots = Math.floor(execAmount * taskReq.numParts / taskReq.amount).toInt - val message = s"The configuration of resource: ${taskReq.resourceName} " + - s"(exec = ${execAmount}, task = ${taskReqStr}, " + - s"runnable tasks = ${resourceNumSlots}) will " + - s"result in wasted resources due to resource ${limitingResourceName} limiting the " + - s"number of runnable tasks per executor to: ${numSlots}. Please adjust " + - s"your configuration." - if (Utils.isTesting) { - throw new SparkException(message) - } else { - logWarning(message) - } - } + validateTaskCpusLargeEnough(executorCores, taskCores) + val defaultProf = sc.resourceProfileManager.defaultResourceProfile + // TODO - this is temporary until all of stage level scheduling feature is integrated, + // fail if any other resource limiting due to dynamic allocation and scheduler using + // slots based on cores + val cpuSlots = executorCores/taskCores + val limitingResource = defaultProf.limitingResource(sc.conf) + if (limitingResource.nonEmpty && !limitingResource.equals(ResourceProfile.CPUS) && + defaultProf.maxTasksPerExecutor(sc.conf) < cpuSlots) { + throw new IllegalArgumentException("The number of slots on an executor has to be " + + "limited by the number of cores, otherwise you waste resources and " + + "dynamic allocation doesn't work properly. Your configuration has " + + s"core/task cpu slots = ${cpuSlots} and " + + s"${limitingResource} = " + + s"${defaultProf.maxTasksPerExecutor(sc.conf)}. 
Please adjust your configuration " + + "so that all resources require same number of executor slots.") } + ResourceUtils.warnOnWastedResources(defaultProf, sc.conf, Some(executorCores)) } master match { case "local" => - checkResourcesPerTask(clusterMode = false, Some(1)) + checkResourcesPerTask(1) val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1) scheduler.initialize(backend) @@ -2886,7 +2822,7 @@ object SparkContext extends Logging { if (threadCount <= 0) { throw new SparkException(s"Asked to run locally with $threadCount threads") } - checkResourcesPerTask(clusterMode = false, Some(threadCount)) + checkResourcesPerTask(threadCount) val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) @@ -2897,14 +2833,13 @@ object SparkContext extends Logging { // local[*, M] means the number of cores on the computer with M failures // local[N, M] means exactly N threads with M failures val threadCount = if (threads == "*") localCpuCount else threads.toInt - checkResourcesPerTask(clusterMode = false, Some(threadCount)) + checkResourcesPerTask(threadCount) val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount) scheduler.initialize(backend) (backend, scheduler) case SPARK_REGEX(sparkUrl) => - checkResourcesPerTask(clusterMode = true, None) val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls) @@ -2912,7 +2847,7 @@ object SparkContext extends Logging { (backend, scheduler) case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - checkResourcesPerTask(clusterMode = true, Some(coresPerSlave.toInt)) + checkResourcesPerTask(coresPerSlave.toInt) // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. 
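// Illustrative sketch, not part of the patch: the slot check that checkResourcesPerTask above now
// delegates to the default ResourceProfile. Plain values stand in for the conf/profile lookups,
// and "cpus"/"gpu" are example resource names.
object CpuSlotCheckSketch {
  def check(executorCores: Int, taskCpus: Int, limitingResource: String, maxTasksPerExecutor: Int): Unit = {
    val cpuSlots = executorCores / taskCpus
    // if a custom resource (e.g. gpu) allows fewer concurrent tasks than the CPU slots,
    // some CPUs could never run a task, so the configuration is rejected
    if (limitingResource.nonEmpty && limitingResource != "cpus" && maxTasksPerExecutor < cpuSlots) {
      throw new IllegalArgumentException(
        s"core/task cpu slots = $cpuSlots but $limitingResource only allows $maxTasksPerExecutor tasks")
    }
  }
  // e.g. 4 cores with 1 cpu per task gives 4 slots; if gpus only allow 2 tasks, the check fails
}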
val memoryPerSlaveInt = memoryPerSlave.toInt if (sc.executorMemory > memoryPerSlaveInt) { @@ -2941,7 +2876,6 @@ object SparkContext extends Logging { (backend, scheduler) case masterUrl => - checkResourcesPerTask(clusterMode = true, None) val cm = getClusterManager(masterUrl) match { case Some(clusterMgr) => clusterMgr case None => throw new SparkException("Could not parse Master URL: '" + master + "'") diff --git a/core/src/main/scala/org/apache/spark/internal/config/Tests.scala b/core/src/main/scala/org/apache/spark/internal/config/Tests.scala index 21660ab3a9512..51df73ebde07d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Tests.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Tests.scala @@ -53,4 +53,13 @@ private[spark] object Tests { val TEST_N_CORES_EXECUTOR = ConfigBuilder("spark.testing.nCoresPerExecutor") .intConf .createWithDefault(2) + + val RESOURCES_WARNING_TESTING = + ConfigBuilder("spark.resources.warnings.testing").booleanConf.createWithDefault(false) + + val RESOURCE_PROFILE_MANAGER_TESTING = + ConfigBuilder("spark.testing.resourceProfileManager") + .booleanConf + .createWithDefault(false) + } diff --git a/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala b/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala index d345674d6635c..d4c29f9a70c44 100644 --- a/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala +++ b/core/src/main/scala/org/apache/spark/resource/ExecutorResourceRequests.scala @@ -109,7 +109,7 @@ private[spark] class ExecutorResourceRequests() extends Serializable { discoveryScript: String = "", vendor: String = ""): this.type = { // a bit weird but for Java api use empty string as meaning None because empty - // string is otherwise invalid for those paramters anyway + // string is otherwise invalid for those parameters anyway val req = new ExecutorResourceRequest(resourceName, amount, discoveryScript, vendor) _executorResources.put(resourceName, req) this diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 14019d27fc2e6..03dcf5e317798 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -22,12 +22,14 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ +import scala.collection.mutable -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.Evolving import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python.PYSPARK_EXECUTOR_MEMORY +import org.apache.spark.util.Utils /** * Resource profile to associate with an RDD. A ResourceProfile allows the user to @@ -42,6 +44,13 @@ class ResourceProfile( // _id is only a var for testing purposes private var _id = ResourceProfile.getNextProfileId + // This is used for any resources that use fractional amounts, the key is the resource name + // and the value is the number of tasks that can share a resource address. For example, + // if the user says task gpu amount is 0.5, that results in 2 tasks per resource address. 
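// Illustrative sketch, not part of the patch: how a fractional task amount becomes "slots per
// resource address", matching the 0.5-gpu example in the comment above.
object FractionalSlotsSketch {
  // returns (whole addresses needed per task, tasks that may share one address)
  def slotsPerAddress(taskAmount: Double): (Int, Int) = {
    if (taskAmount <= 0.5) (1, math.floor(1.0 / taskAmount).toInt)
    else if (taskAmount % 1 == 0) (taskAmount.toInt, 1)
    else throw new IllegalArgumentException(s"amount $taskAmount must be <= 0.5 or a whole number")
  }
  // slotsPerAddress(0.5) == (1, 2): two tasks share each gpu address
  // slotsPerAddress(2.0) == (2, 1): each task needs two whole addresses to itself
}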
+ private var _executorResourceSlotsPerAddr: Option[Map[String, Int]] = None + private var _limitingResource: Option[String] = None + private var _maxTasksPerExecutor: Option[Int] = None + private var _coresLimitKnown: Boolean = false def id: Int = _id @@ -67,6 +76,138 @@ class ResourceProfile( taskResources.get(ResourceProfile.CPUS).map(_.amount.toInt) } + private[spark] def getNumSlotsPerAddress(resource: String, sparkConf: SparkConf): Int = { + _executorResourceSlotsPerAddr.getOrElse { + calculateTasksAndLimitingResource(sparkConf) + } + _executorResourceSlotsPerAddr.get.getOrElse(resource, + throw new SparkException(s"Resource $resource doesn't exist in profile id: $id")) + } + + // Maximum tasks you could put on an executor with this profile based on the limiting resource. + // If the executor cores config is not present this value is based on the other resources + // available or 1 if no other resources. You need to check the isCoresLimitKnown to + // calculate proper value. + private[spark] def maxTasksPerExecutor(sparkConf: SparkConf): Int = { + _maxTasksPerExecutor.getOrElse { + calculateTasksAndLimitingResource(sparkConf) + _maxTasksPerExecutor.get + } + } + + // Returns whether the executor cores was available to use to calculate the max tasks + // per executor and limiting resource. Some cluster managers (like standalone and coarse + // grained mesos) don't use the cores config by default so we can't use it to calculate slots. + private[spark] def isCoresLimitKnown: Boolean = _coresLimitKnown + + // The resource that has the least amount of slots per executor. Its possible multiple or all + // resources result in same number of slots and this could be any of those. + // If the executor cores config is not present this value is based on the other resources + // available or empty string if no other resources. You need to check the isCoresLimitKnown to + // calculate proper value. + private[spark] def limitingResource(sparkConf: SparkConf): String = { + _limitingResource.getOrElse { + calculateTasksAndLimitingResource(sparkConf) + _limitingResource.get + } + } + + // executor cores config is not set for some masters by default and the default value + // only applies to yarn/k8s + private def shouldCheckExecutorCores(sparkConf: SparkConf): Boolean = { + val master = sparkConf.getOption("spark.master") + sparkConf.contains(EXECUTOR_CORES) || + (master.isDefined && (master.get.equalsIgnoreCase("yarn") || master.get.startsWith("k8s"))) + } + + /** + * Utility function to calculate the number of tasks you can run on a single Executor based + * on the task and executor resource requests in the ResourceProfile. This will be based + * off the resource that is most restrictive. For instance, if the executor + * request is for 4 cpus and 2 gpus and your task request is for 1 cpu and 1 gpu each, the + * limiting resource is gpu and the number of tasks you can run on a single executor is 2. + * This function also sets the limiting resource, isCoresLimitKnown and number of slots per + * resource address. 
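// Illustrative sketch, not part of the patch: the worked example from the comment above, as a
// simplified tasks-per-executor calculation that ignores fractional amounts. For each requested
// resource it divides what the executor has by what one task needs and takes the minimum.
object LimitingResourceSketch {
  def limiting(executorAmounts: Map[String, Double], taskAmounts: Map[String, Double]): (String, Int) = {
    val slotsPerResource = taskAmounts.collect { case (name, perTask) if perTask > 0.0 =>
      name -> math.floor(executorAmounts.getOrElse(name, 0.0) / perTask).toInt
    }
    slotsPerResource.minBy(_._2) // (limiting resource name, tasks per executor)
  }
  // limiting(Map("cpus" -> 4.0, "gpu" -> 2.0), Map("cpus" -> 1.0, "gpu" -> 1.0)) == ("gpu", 2)
}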
+ */ + private def calculateTasksAndLimitingResource(sparkConf: SparkConf): Unit = synchronized { + val shouldCheckExecCores = shouldCheckExecutorCores(sparkConf) + var (taskLimit, limitingResource) = if (shouldCheckExecCores) { + val cpusPerTask = taskResources.get(ResourceProfile.CPUS) + .map(_.amount).getOrElse(sparkConf.get(CPUS_PER_TASK).toDouble).toInt + assert(cpusPerTask > 0, "CPUs per task configuration has to be > 0") + val coresPerExecutor = getExecutorCores.getOrElse(sparkConf.get(EXECUTOR_CORES)) + _coresLimitKnown = true + ResourceUtils.validateTaskCpusLargeEnough(coresPerExecutor, cpusPerTask) + val tasksBasedOnCores = coresPerExecutor / cpusPerTask + // Note that if the cores per executor aren't set properly this calculation could be off, + // we default it to just be 1 in order to allow checking of the rest of the custom + // resources. We set the limit based on the other resources available. + (tasksBasedOnCores, ResourceProfile.CPUS) + } else { + (-1, "") + } + val numPartsPerResourceMap = new mutable.HashMap[String, Int] + numPartsPerResourceMap(ResourceProfile.CORES) = 1 + val taskResourcesToCheck = new mutable.HashMap[String, TaskResourceRequest] + taskResourcesToCheck ++= ResourceProfile.getCustomTaskResources(this) + val execResourceToCheck = ResourceProfile.getCustomExecutorResources(this) + execResourceToCheck.foreach { case (rName, execReq) => + val taskReq = taskResources.get(rName).map(_.amount).getOrElse(0.0) + numPartsPerResourceMap(rName) = 1 + if (taskReq > 0.0) { + if (taskReq > execReq.amount) { + throw new SparkException(s"The executor resource: $rName, amount: ${execReq.amount} " + + s"needs to be >= the task resource request amount of $taskReq") + } + val (numPerTask, parts) = ResourceUtils.calculateAmountAndPartsForFraction(taskReq) + numPartsPerResourceMap(rName) = parts + val numTasks = ((execReq.amount * parts) / numPerTask).toInt + if (taskLimit == -1 || numTasks < taskLimit) { + if (shouldCheckExecCores) { + // TODO - until resource profiles full implemented we need to error if cores not + // limiting resource because the scheduler code uses that for slots + throw new IllegalArgumentException("The number of slots on an executor has to be " + + "limited by the number of cores, otherwise you waste resources and " + + "dynamic allocation doesn't work properly. Your configuration has " + + s"core/task cpu slots = ${taskLimit} and " + + s"${execReq.resourceName} = ${numTasks}. " + + "Please adjust your configuration so that all resources require same number " + + "of executor slots.") + } + limitingResource = rName + taskLimit = numTasks + } + taskResourcesToCheck -= rName + } else { + logWarning(s"The executor resource config for resource: $rName was specified but " + + "no corresponding task resource request was specified.") + } + } + if(!shouldCheckExecCores) { + // if we can't rely on the executor cores config throw a warning for user + logWarning("Please ensure that the number of slots available on your " + + "executors is limited by the number of cores to task cpus and not another " + + "custom resource. 
If cores is not the limiting resource then dynamic " + + "allocation will not work properly!") + } + if (taskResourcesToCheck.nonEmpty) { + throw new SparkException("No executor resource configs were not specified for the " + + s"following task configs: ${taskResourcesToCheck.keys.mkString(",")}") + } + logInfo(s"Limiting resource is $limitingResource at $taskLimit tasks per executor") + _executorResourceSlotsPerAddr = Some(numPartsPerResourceMap.toMap) + _maxTasksPerExecutor = if (taskLimit == -1) Some(1) else Some(taskLimit) + _limitingResource = Some(limitingResource) + if (shouldCheckExecCores) { + ResourceUtils.warnOnWastedResources(this, sparkConf) + } + } + + // to be used only by history server for reconstruction from events + private[spark] def setResourceProfileId(id: Int): Unit = { + _id = id + } + // testing only private[spark] def setToDefaultProfile(): Unit = { _id = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID @@ -123,7 +264,7 @@ object ResourceProfile extends Logging { val taskResources = getDefaultTaskResources(conf) val executorResources = getDefaultExecutorResources(conf) val defProf = new ResourceProfile(executorResources, taskResources) - defProf.setToDefaultProfile + defProf.setToDefaultProfile() defaultProfile = Some(defProf) logInfo("Default ResourceProfile created, executor resources: " + s"${defProf.executorResources}, task resources: " + @@ -157,13 +298,12 @@ object ResourceProfile extends Logging { // for testing only private[spark] def reInitDefaultProfile(conf: SparkConf): Unit = { - clearDefaultProfile + clearDefaultProfile() // force recreate it after clearing getOrCreateDefaultProfile(conf) } - // for testing only - private[spark] def clearDefaultProfile: Unit = { + private[spark] def clearDefaultProfile(): Unit = { DEFAULT_PROFILE_LOCK.synchronized { defaultProfile = None } diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala index 0d55c176eeb65..26f23f4bf0476 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileBuilder.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.Evolving * requirements between stages. */ @Evolving -class ResourceProfileBuilder() { +private[spark] class ResourceProfileBuilder() { private val _taskResources = new ConcurrentHashMap[String, TaskResourceRequest]() private val _executorResources = new ConcurrentHashMap[String, ExecutorResourceRequest]() diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala new file mode 100644 index 0000000000000..06db9468c451e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.resource + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.annotation.Evolving +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Tests._ +import org.apache.spark.util.Utils +import org.apache.spark.util.Utils.isTesting + +/** + * Manager of resource profiles. The manager allows one place to keep the actual ResourceProfiles + * and everywhere else we can use the ResourceProfile Id to save on space. + * Note we never remove a resource profile at this point. It's expected this number is small + * so this shouldn't be much overhead. + */ +@Evolving +private[spark] class ResourceProfileManager(sparkConf: SparkConf) extends Logging { + private val resourceProfileIdToResourceProfile = new ConcurrentHashMap[Int, ResourceProfile]() + + private val defaultProfile = ResourceProfile.getOrCreateDefaultProfile(sparkConf) + addResourceProfile(defaultProfile) + + def defaultResourceProfile: ResourceProfile = defaultProfile + + private val taskCpusDefaultProfile = defaultProfile.getTaskCpus.get + private val dynamicEnabled = Utils.isDynamicAllocationEnabled(sparkConf) + private val master = sparkConf.getOption("spark.master") + private val isNotYarn = master.isDefined && !master.get.equals("yarn") + private val errorForTesting = !isTesting || sparkConf.get(RESOURCE_PROFILE_MANAGER_TESTING) + + // If we use anything except the default profile, it's only supported on YARN right now. + // Throw an exception if not supported. + private[spark] def isSupported(rp: ResourceProfile): Boolean = { + val isNotDefaultProfile = rp.id != ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID + val notYarnAndNotDefaultProfile = isNotDefaultProfile && isNotYarn + val YarnNotDynAllocAndNotDefaultProfile = isNotDefaultProfile && !isNotYarn && !dynamicEnabled + if (errorForTesting && (notYarnAndNotDefaultProfile || YarnNotDynAllocAndNotDefaultProfile)) { + throw new SparkException("ResourceProfiles are only supported on YARN with dynamic " + + "allocation enabled.") + } + true + } + + def addResourceProfile(rp: ResourceProfile): Unit = { + isSupported(rp) + // force the computation of maxTasks and limitingResource now so we don't have cost later + rp.limitingResource(sparkConf) + logInfo(s"Adding ResourceProfile id: ${rp.id}") + resourceProfileIdToResourceProfile.putIfAbsent(rp.id, rp) + } + + /* + * Gets the ResourceProfile associated with the id; if a profile doesn't exist + * it throws a SparkException.
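// Illustrative sketch, not part of the patch: a minimal model of the registry above. Profiles are
// kept in a map keyed by id, a lookup of an unknown id fails, and non-default profiles are only
// accepted on YARN with dynamic allocation enabled (the testing-only escape hatch is omitted).
object ProfileRegistrySketch {
  import java.util.concurrent.ConcurrentHashMap

  case class Profile(id: Int)                      // stand-in for ResourceProfile
  val defaultProfileId = 0                         // assumed id of the default profile
  private val profiles = new ConcurrentHashMap[Int, Profile]()
  profiles.put(defaultProfileId, Profile(defaultProfileId))

  def add(p: Profile, isYarn: Boolean, dynamicAllocationEnabled: Boolean): Unit = {
    val supported = p.id == defaultProfileId || (isYarn && dynamicAllocationEnabled)
    require(supported, "ResourceProfiles are only supported on YARN with dynamic allocation enabled.")
    profiles.putIfAbsent(p.id, p)
  }

  def byId(id: Int): Profile = {
    val p = profiles.get(id)
    if (p == null) throw new IllegalArgumentException(s"ResourceProfileId $id not found!") else p
  }
}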
+ */ + def resourceProfileFromId(rpId: Int): ResourceProfile = { + val rp = resourceProfileIdToResourceProfile.get(rpId) + if (rp == null) { + throw new SparkException(s"ResourceProfileId $rpId not found!") + } + rp + } + + def taskCpusForProfileId(rpId: Int): Int = { + resourceProfileFromId(rpId).getTaskCpus.getOrElse(taskCpusDefaultProfile) + } +} diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala index 7dd7fc1b99353..cdb761c7566e7 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceUtils.scala @@ -29,7 +29,8 @@ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.resource.ResourceDiscoveryPlugin import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{RESOURCES_DISCOVERY_PLUGIN, SPARK_TASK_PREFIX} +import org.apache.spark.internal.config.{CPUS_PER_TASK, EXECUTOR_CORES, RESOURCES_DISCOVERY_PLUGIN, SPARK_TASK_PREFIX} +import org.apache.spark.internal.config.Tests.{RESOURCES_WARNING_TESTING} import org.apache.spark.util.Utils /** @@ -161,19 +162,23 @@ private[spark] object ResourceUtils extends Logging { } // Used to take a fraction amount from a task resource requirement and split into a real - // integer amount and the number of parts expected. For instance, if the amount is 0.5, - // the we get (1, 2) back out. - // Returns tuple of (amount, numParts) - def calculateAmountAndPartsForFraction(amount: Double): (Int, Int) = { - val parts = if (amount <= 0.5) { - Math.floor(1.0 / amount).toInt - } else if (amount % 1 != 0) { + // integer amount and the number of slots per address. For instance, if the amount is 0.5, + // the we get (1, 2) back out. This indicates that for each 1 address, it has 2 slots per + // address, which allows you to put 2 tasks on that address. Note if amount is greater + // than 1, then the number of slots per address has to be 1. This would indicate that a + // would have multiple addresses assigned per task. This can be used for calculating + // the number of tasks per executor -> (executorAmount * numParts) / (integer amount). + // Returns tuple of (integer amount, numParts) + def calculateAmountAndPartsForFraction(doubleAmount: Double): (Int, Int) = { + val parts = if (doubleAmount <= 0.5) { + Math.floor(1.0 / doubleAmount).toInt + } else if (doubleAmount % 1 != 0) { throw new SparkException( - s"The resource amount ${amount} must be either <= 0.5, or a whole number.") + s"The resource amount ${doubleAmount} must be either <= 0.5, or a whole number.") } else { 1 } - (Math.ceil(amount).toInt, parts) + (Math.ceil(doubleAmount).toInt, parts) } // Add any task resource requests from the spark conf to the TaskResourceRequests passed in @@ -382,6 +387,90 @@ private[spark] object ResourceUtils extends Logging { s"${resourceRequest.id.resourceName}") } + def validateTaskCpusLargeEnough(execCores: Int, taskCpus: Int): Boolean = { + // Number of cores per executor must meet at least one task requirement. 
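+    // Illustrative example only (not part of this change); config names shown for context:
+    // with spark.executor.cores=2 and spark.task.cpus=4 no task could ever be scheduled, so
+    // the check below fails fast, while spark.executor.cores=8 with spark.task.cpus=4 passes:
+    //   validateTaskCpusLargeEnough(execCores = 2, taskCpus = 4)  // throws SparkException
+    //   validateTaskCpusLargeEnough(execCores = 8, taskCpus = 4)  // returns true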
+ if (execCores < taskCpus) { + throw new SparkException(s"The number of cores per executor (=$execCores) has to be >= " + + s"the number of cpus per task = $taskCpus.") + } + true + } + + // the option executor cores parameter is by the different local modes since it not configured + // via the config + def warnOnWastedResources( + rp: ResourceProfile, + sparkConf: SparkConf, + execCores: Option[Int] = None): Unit = { + // There have been checks on the ResourceProfile to make sure the executor resources were + // specified and are large enough if any task resources were specified. + // Now just do some sanity test and log warnings when it looks like the user will + // waste some resources. + val coresKnown = rp.isCoresLimitKnown + var limitingResource = rp.limitingResource(sparkConf) + var maxTaskPerExec = rp.maxTasksPerExecutor(sparkConf) + val taskCpus = rp.getTaskCpus.getOrElse(sparkConf.get(CPUS_PER_TASK)) + val cores = if (execCores.isDefined) { + execCores.get + } else if (coresKnown) { + rp.getExecutorCores.getOrElse(sparkConf.get(EXECUTOR_CORES)) + } else { + // can't calculate cores limit + return + } + // when executor cores config isn't set, we can't calculate the real limiting resource + // and number of tasks per executor ahead of time, so calculate it now. + if (!coresKnown) { + val numTasksPerExecCores = cores / taskCpus + val numTasksPerExecCustomResource = rp.maxTasksPerExecutor(sparkConf) + if (limitingResource.isEmpty || + (limitingResource.nonEmpty && numTasksPerExecCores < numTasksPerExecCustomResource)) { + limitingResource = ResourceProfile.CPUS + maxTaskPerExec = numTasksPerExecCores + } + } + val taskReq = ResourceProfile.getCustomTaskResources(rp) + val execReq = ResourceProfile.getCustomExecutorResources(rp) + + if (limitingResource.nonEmpty && !limitingResource.equals(ResourceProfile.CPUS)) { + if ((taskCpus * maxTaskPerExec) < cores) { + val resourceNumSlots = Math.floor(cores/taskCpus).toInt + val message = s"The configuration of cores (exec = ${cores} " + + s"task = ${taskCpus}, runnable tasks = ${resourceNumSlots}) will " + + s"result in wasted resources due to resource ${limitingResource} limiting the " + + s"number of runnable tasks per executor to: ${maxTaskPerExec}. Please adjust " + + "your configuration." + if (sparkConf.get(RESOURCES_WARNING_TESTING)) { + throw new SparkException(message) + } else { + logWarning(message) + } + } + } + + taskReq.foreach { case (rName, treq) => + val execAmount = execReq(rName).amount + val numParts = rp.getNumSlotsPerAddress(rName, sparkConf) + // handle fractional + val taskAmount = if (numParts > 1) 1 else treq.amount + if (maxTaskPerExec < (execAmount * numParts / taskAmount)) { + val taskReqStr = s"${taskAmount}/${numParts}" + val resourceNumSlots = Math.floor(execAmount * numParts / taskAmount).toInt + val message = s"The configuration of resource: ${treq.resourceName} " + + s"(exec = ${execAmount}, task = ${taskReqStr}, " + + s"runnable tasks = ${resourceNumSlots}) will " + + s"result in wasted resources due to resource ${limitingResource} limiting the " + + s"number of runnable tasks per executor to: ${maxTaskPerExec}. Please adjust " + + "your configuration." 
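+        // Worked example (illustrative numbers, not taken from this change): with 16 executor
+        // cores, spark.task.cpus=1, 2 GPUs per executor and 1 GPU per task, the GPU is the
+        // limiting resource and maxTaskPerExec = 2, so taskCpus * maxTaskPerExec = 2 < 16 and
+        // 14 of the 16 cores per executor would sit idle; that is what this warning reports.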
+ if (sparkConf.get(RESOURCES_WARNING_TESTING)) { + throw new SparkException(message) + } else { + logWarning(message) + } + } + } + } + // known types of resources final val GPU: String = "gpu" final val FPGA: String = "fpga" diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7bf363dd71c1b..fd5c3e0827bf9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -37,7 +37,8 @@ import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -391,7 +392,8 @@ private[spark] class DAGScheduler( val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() val stage = new ShuffleMapStage( - id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker) + id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker, + ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) stageIdToStage(id) = stage shuffleIdToMapStage(shuffleDep.shuffleId) = stage @@ -453,7 +455,8 @@ private[spark] class DAGScheduler( checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size) val parents = getOrCreateParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite) + val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite, + ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala index d1687830ff7bf..7fdc3186e86bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala @@ -34,8 +34,9 @@ private[spark] class ResultStage( val partitions: Array[Int], parents: List[Stage], firstJobId: Int, - callSite: CallSite) - extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite) { + callSite: CallSite, + resourceProfileId: Int) + extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite, resourceProfileId) { /** * The active job for this result stage. 
Will be empty if the job has already finished diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 1b44d0aee3195..be1984de9837f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -42,8 +42,9 @@ private[spark] class ShuffleMapStage( firstJobId: Int, callSite: CallSite, val shuffleDep: ShuffleDependency[_, _, _], - mapOutputTrackerMaster: MapOutputTrackerMaster) - extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) { + mapOutputTrackerMaster: MapOutputTrackerMaster, + resourceProfileId: Int) + extends Stage(id, rdd, numTasks, parents, firstJobId, callSite, resourceProfileId) { private[this] var _mapStageJobs: List[ActiveJob] = Nil diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index a9f72eae71368..ae7924d66a301 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -59,7 +59,8 @@ private[scheduler] abstract class Stage( val numTasks: Int, val parents: List[Stage], val firstJobId: Int, - val callSite: CallSite) + val callSite: CallSite, + val resourceProfileId: Int) extends Logging { val numPartitions = rdd.partitions.length @@ -79,7 +80,8 @@ private[scheduler] abstract class Stage( * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts * have been created). */ - private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) + private var _latestInfo: StageInfo = + StageInfo.fromStage(this, nextAttemptId, resourceProfileId = resourceProfileId) /** * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid @@ -100,7 +102,8 @@ private[scheduler] abstract class Stage( val metrics = new TaskMetrics metrics.register(rdd.sparkContext) _latestInfo = StageInfo.fromStage( - this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences) + this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences, + resourceProfileId = resourceProfileId) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index fdc50328b43d8..556478d83cf39 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -38,7 +38,8 @@ class StageInfo( val details: String, val taskMetrics: TaskMetrics = null, private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty, - private[spark] val shuffleDepId: Option[Int] = None) { + private[spark] val shuffleDepId: Option[Int] = None, + val resourceProfileId: Int) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None /** Time when all tasks in the stage completed or when the stage was cancelled. 
*/ @@ -87,7 +88,8 @@ private[spark] object StageInfo { attemptId: Int, numTasks: Option[Int] = None, taskMetrics: TaskMetrics = null, - taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty + taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty, + resourceProfileId: Int ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos @@ -105,6 +107,7 @@ private[spark] object StageInfo { stage.details, taskMetrics, taskLocalityPreferences, - shuffleDepId) + shuffleDepId, + resourceProfileId) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 6a1d460e6a9d9..bf92081d13907 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -384,7 +384,9 @@ private[spark] class TaskSchedulerImpl( */ private def resourcesMeetTaskRequirements(resources: Map[String, Buffer[String]]): Boolean = { val resourcesFree = resources.map(r => r._1 -> r._2.length) - ResourceUtils.resourcesMeetRequirements(resourcesFree, resourcesReqsPerTask) + val meetsReqs = ResourceUtils.resourcesMeetRequirements(resourcesFree, resourcesReqsPerTask) + logDebug(s"Resources meet task requirements is: $meetsReqs") + meetsReqs } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 55f4005ef1b45..63aa04986b073 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,13 +69,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp conf.get(SCHEDULER_MAX_REGISTERED_RESOURCE_WAITING_TIME)) private val createTimeNs = System.nanoTime() - private val taskResourceNumParts: Map[String, Int] = - if (scheduler.resourcesReqsPerTask != null) { - scheduler.resourcesReqsPerTask.map(req => req.resourceName -> req.numParts).toMap - } else { - Map.empty - } - // Accessing `executorDataMap` in the inherited methods from ThreadSafeRpcEndpoint doesn't need // any protection. But accessing `executorDataMap` out of the inherited methods must be // protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should only @@ -83,13 +76,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. 
private val executorDataMap = new HashMap[String, ExecutorData] - // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] - @GuardedBy("CoarseGrainedSchedulerBackend.this") - private var requestedTotalExecutors = 0 - - // Number of executors requested from the cluster manager that have not registered yet + // Number of executors for each ResourceProfile requested by the cluster + // manager, [[ExecutorAllocationManager]] @GuardedBy("CoarseGrainedSchedulerBackend.this") - private var numPendingExecutors = 0 + private val requestedTotalExecutorsPerResourceProfile = new HashMap[ResourceProfile, Int] private val listenerBus = scheduler.sc.listenerBus @@ -102,13 +92,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. private val executorsPendingLossReason = new HashSet[String] - // A map to store hostname with its possible task number running on it + // A map of ResourceProfile id to map of hostname with its possible task number running on it @GuardedBy("CoarseGrainedSchedulerBackend.this") - protected var hostToLocalTaskCount: Map[String, Int] = Map.empty + protected var rpHostToLocalTaskCount: Map[Int, Map[String, Int]] = Map.empty - // The number of pending tasks which is locality required + // The number of pending tasks per ResourceProfile id which is locality required @GuardedBy("CoarseGrainedSchedulerBackend.this") - protected var localityAwareTasks = 0 + protected var numLocalityAwareTasksPerResourceProfileId = Map.empty[Int, Int] // The num of current max ExecutorId used to re-register appMaster @volatile protected var currentExecutorIdCounter = 0 @@ -223,16 +213,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } else { context.senderAddress } - logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId") + logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId, " + + s" ResourceProfileId $resourceProfileId") addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) - val resourcesInfo = resources.map{ case (k, v) => - (v.name, - new ExecutorResourceInfo(v.name, v.addresses, - // tell the executor it can schedule resources up to numParts times, - // as configured by the user, or set to 1 as that is the default (1 task/resource) - taskResourceNumParts.getOrElse(v.name, 1))) + val resourcesInfo = resources.map { case (rName, info) => + // tell the executor it can schedule resources up to numParts times, + // as configured by the user, or set to 1 as that is the default (1 task/resource) + val numParts = scheduler.sc.resourceProfileManager + .resourceProfileFromId(resourceProfileId).getNumSlotsPerAddress(rName, conf) + (info.name, new ExecutorResourceInfo(info.name, info.addresses, numParts)) } val data = new ExecutorData(executorRef, executorAddress, hostname, 0, cores, logUrlHandler.applyPattern(logUrls, attributes), attributes, @@ -244,10 +235,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (currentExecutorIdCounter < executorId.toInt) { currentExecutorIdCounter = executorId.toInt } - if (numPendingExecutors > 0) { - numPendingExecutors -= 1 - logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") - } } // Note: some tests expect the reply to come after we put the executor in the map context.reply(true) @@ -271,10 +258,7 
@@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp context.reply(true) case RetrieveSparkAppConfig(resourceProfileId) => - // note this will be updated in later prs to get the ResourceProfile from a - // ResourceProfileManager that matches the resource profile id - // for now just use default profile - val rp = ResourceProfile.getOrCreateDefaultProfile(conf) + val rp = scheduler.sc.resourceProfileManager.resourceProfileFromId(resourceProfileId) val reply = SparkAppConfig( sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey(), @@ -494,8 +478,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected[scheduler] def reset(): Unit = { val executors: Set[String] = synchronized { - requestedTotalExecutors = 0 - numPendingExecutors = 0 + requestedTotalExecutorsPerResourceProfile.clear() executorDataMap.keys.toSet } @@ -577,12 +560,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // this function is for testing only def getExecutorResourceProfileId(executorId: String): Int = synchronized { - val res = executorDataMap.get(executorId) - res.map(_.resourceProfileId).getOrElse(ResourceProfile.UNKNOWN_RESOURCE_PROFILE_ID) + val execDataOption = executorDataMap.get(executorId) + execDataOption.map(_.resourceProfileId).getOrElse(ResourceProfile.UNKNOWN_RESOURCE_PROFILE_ID) } /** - * Request an additional number of executors from the cluster manager. + * Request an additional number of executors from the cluster manager. This is + * requesting against the default ResourceProfile, we will need an API change to + * allow against other profiles. * @return whether the request is acknowledged. */ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = { @@ -594,21 +579,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { - requestedTotalExecutors += numAdditionalExecutors - numPendingExecutors += numAdditionalExecutors - logDebug(s"Number of pending executors is now $numPendingExecutors") - if (requestedTotalExecutors != - (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { - logDebug( - s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: - |requestedTotalExecutors = $requestedTotalExecutors - |numExistingExecutors = $numExistingExecutors - |numPendingExecutors = $numPendingExecutors - |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) - } - + val defaultProf = scheduler.sc.resourceProfileManager.defaultResourceProfile + val numExisting = requestedTotalExecutorsPerResourceProfile.getOrElse(defaultProf, 0) + requestedTotalExecutorsPerResourceProfile(defaultProf) = numExisting + numAdditionalExecutors // Account for executors pending to be added or removed - doRequestTotalExecutors(requestedTotalExecutors) + doRequestTotalExecutors(requestedTotalExecutorsPerResourceProfile.toMap) } defaultAskTimeout.awaitResult(response) @@ -617,39 +592,41 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. - * @param numExecutors The total number of executors we'd like to have. 
The cluster manager - * shouldn't kill any running executor to reach this number, but, - * if all existing executors were to die, this is the number of executors - * we'd want to be allocated. - * @param localityAwareTasks The number of tasks in all active stages that have a locality - * preferences. This includes running, pending, and completed tasks. + * @param resourceProfileToNumExecutors The total number of executors we'd like to have per + * ResourceProfile. The cluster manager shouldn't kill any + * running executor to reach this number, but, if all + * existing executors were to die, this is the number + * of executors we'd want to be allocated. + * @param numLocalityAwareTasksPerResourceProfileId The number of tasks in all active stages that + * have a locality preferences per + * ResourceProfile. This includes running, + * pending, and completed tasks. * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages * that would like to like to run on that host. * This includes running, pending, and completed tasks. * @return whether the request is acknowledged by the cluster manager. */ final override def requestTotalExecutors( - numExecutors: Int, - localityAwareTasks: Int, - hostToLocalTaskCount: Map[String, Int] - ): Boolean = { - if (numExecutors < 0) { + resourceProfileIdToNumExecutors: Map[Int, Int], + numLocalityAwareTasksPerResourceProfileId: Map[Int, Int], + hostToLocalTaskCount: Map[Int, Map[String, Int]] + ): Boolean = { + val totalExecs = resourceProfileIdToNumExecutors.values.sum + if (totalExecs < 0) { throw new IllegalArgumentException( "Attempted to request a negative number of executor(s) " + - s"$numExecutors from the cluster manager. Please specify a positive number!") + s"$totalExecs from the cluster manager. Please specify a positive number!") + } + val resourceProfileToNumExecutors = resourceProfileIdToNumExecutors.map { case (rpid, num) => + (scheduler.sc.resourceProfileManager.resourceProfileFromId(rpid), num) } - val response = synchronized { - this.requestedTotalExecutors = numExecutors - this.localityAwareTasks = localityAwareTasks - this.hostToLocalTaskCount = hostToLocalTaskCount - - numPendingExecutors = - math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0) - - doRequestTotalExecutors(numExecutors) + this.requestedTotalExecutorsPerResourceProfile.clear() + this.requestedTotalExecutorsPerResourceProfile ++= resourceProfileToNumExecutors + this.numLocalityAwareTasksPerResourceProfileId = numLocalityAwareTasksPerResourceProfileId + this.rpHostToLocalTaskCount = hostToLocalTaskCount + doRequestTotalExecutors(requestedTotalExecutorsPerResourceProfile.toMap) } - defaultAskTimeout.awaitResult(response) } @@ -665,7 +642,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * * @return a future whose evaluation indicates whether the request is acknowledged. */ - protected def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = + protected def doRequestTotalExecutors( + resourceProfileToTotalExecs: Map[ResourceProfile, Int]): Future[Boolean] = Future.successful(false) /** @@ -706,20 +684,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. 
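    // Illustrative example (values made up): killing two executors that both belong to the
    // default profile lowers requestedTotalExecutorsPerResourceProfile(defaultProfile) by 2,
    // never dropping below zero, before the new per-profile totals are re-sent to the
    // cluster manager via doRequestTotalExecutors.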
val adjustTotalExecutors = if (adjustTargetNumExecutors) { - requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) - if (requestedTotalExecutors != - (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { - logDebug( - s"""killExecutors($executorIds, $adjustTargetNumExecutors, $countFailures, $force): - |Executor counts do not match: - |requestedTotalExecutors = $requestedTotalExecutors - |numExistingExecutors = $numExistingExecutors - |numPendingExecutors = $numPendingExecutors - |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + executorsToKill.foreach { exec => + val rpId = executorDataMap(exec).resourceProfileId + val rp = scheduler.sc.resourceProfileManager.resourceProfileFromId(rpId) + if (requestedTotalExecutorsPerResourceProfile.isEmpty) { + // Assume that we are killing an executor that was started by default and + // not through the request api + requestedTotalExecutorsPerResourceProfile(rp) = 0 + } else { + val requestedTotalForRp = requestedTotalExecutorsPerResourceProfile(rp) + requestedTotalExecutorsPerResourceProfile(rp) = math.max(requestedTotalForRp - 1, 0) + } } - doRequestTotalExecutors(requestedTotalExecutors) + doRequestTotalExecutors(requestedTotalExecutorsPerResourceProfile.toMap) } else { - numPendingExecutors += executorsToKill.size Future.successful(true) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index a9b607d8cc38c..d91d78b29f98d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -28,7 +28,7 @@ import org.apache.spark.deploy.client.{StandaloneAppClient, StandaloneAppClientL import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} -import org.apache.spark.resource.ResourceUtils +import org.apache.spark.resource.{ResourceProfile, ResourceUtils} import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -58,6 +58,7 @@ private[spark] class StandaloneSchedulerBackend( private val maxCores = conf.get(config.CORES_MAX) private val totalExpectedCores = maxCores.getOrElse(0) + private val defaultProf = sc.resourceProfileManager.defaultResourceProfile override def start(): Unit = { super.start() @@ -194,9 +195,13 @@ private[spark] class StandaloneSchedulerBackend( * * @return whether the request is acknowledged. 
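   * Illustrative note (not part of this change): standalone mode only forwards the count for
   * the default profile, so a request like Map(defaultProf -> 4, someOtherProfile -> 2) ends
   * up as c.requestTotalExecutors(4); `someOtherProfile` is a made-up name for illustration.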
*/ - protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { + protected override def doRequestTotalExecutors( + resourceProfileToTotalExecs: Map[ResourceProfile, Int]): Future[Boolean] = { + // resources profiles not supported Option(client) match { - case Some(c) => c.requestTotalExecutors(requestedTotal) + case Some(c) => + val numExecs = resourceProfileToTotalExecs.getOrElse(defaultProf, 0) + c.requestTotalExecutors(numExecs) case None => logWarning("Attempted to request executors before driver fully initialized.") Future.successful(false) diff --git a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala index a24f1902faa31..c29546b7577fc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitor.scala @@ -70,7 +70,7 @@ private[spark] class ExecutorMonitor( // this listener. There are safeguards in other parts of the code that would prevent that executor // from being removed. private val nextTimeout = new AtomicLong(Long.MaxValue) - private var timedOutExecs = Seq.empty[String] + private var timedOutExecs = Seq.empty[(String, Int)] // Active job tracking. // @@ -100,10 +100,10 @@ private[spark] class ExecutorMonitor( } /** - * Returns the list of executors that are currently considered to be timed out. - * Should only be called from the EAM thread. + * Returns the list of executors and their ResourceProfile id that are currently considered to + * be timed out. Should only be called from the EAM thread. */ - def timedOutExecutors(): Seq[String] = { + def timedOutExecutors(): Seq[(String, Int)] = { val now = clock.nanoTime() if (now >= nextTimeout.get()) { // Temporarily set the next timeout at Long.MaxValue. 
This ensures that after @@ -126,7 +126,7 @@ private[spark] class ExecutorMonitor( true } } - .keys + .map { case (name, exec) => (name, exec.resourceProfileId)} .toSeq updateNextTimeout(newNextTimeout) } @@ -155,6 +155,7 @@ private[spark] class ExecutorMonitor( execResourceProfileCount.getOrDefault(id, 0) } + // for testing def getResourceProfileId(executorId: String): Int = { val execTrackingInfo = executors.get(executorId) if (execTrackingInfo != null) { diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 4d89c4f079f29..53824735d2fc5 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -33,7 +33,7 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.resource.ResourceInformation +import org.apache.spark.resource.{ResourceInformation, ResourceProfile} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ @@ -662,7 +662,8 @@ private[spark] object JsonProtocol { val stageInfos = jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => - new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") + new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) } } SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) @@ -803,7 +804,8 @@ private[spark] object JsonProtocol { } val stageInfo = new StageInfo( - stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details) + stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 8d958494d52be..8fa33f4915ea4 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_SCHEDULE_INTERVAL import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, ResourceProfileBuilder, ResourceProfileManager, TaskResourceRequests} import org.apache.spark.resource.ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo @@ -45,6 +46,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { private val managers = new mutable.ListBuffer[ExecutorAllocationManager]() private var listenerBus: LiveListenerBus = _ private var client: ExecutorAllocationClient = _ + private val clock = new SystemClock() + private var rpManager: ResourceProfileManager = _ + override def beforeEach(): Unit = { super.beforeEach() @@ -108,65 +112,257 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { test("starting state") { val manager = createManager(createConf()) - 
assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) assert(executorsPendingToRemove(manager).isEmpty) assert(addTime(manager) === ExecutorAllocationManager.NOT_SET) } - test("add executors") { + test("add executors default profile") { val manager = createManager(createConf(1, 10, 1)) post(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + + // Keep adding until the limit is reached + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 4) + assert(numExecutorsToAddForDefaultProfile(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 4) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) + assert(numExecutorsToAddForDefaultProfile(manager) === 8) + // reached the limit of 10 + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + + // Register previously requested executors + onExecutorAddedDefaultProfile(manager, "first") + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + onExecutorAddedDefaultProfile(manager, "second") + onExecutorAddedDefaultProfile(manager, "third") + onExecutorAddedDefaultProfile(manager, "fourth") + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + onExecutorAddedDefaultProfile(manager, "first") // duplicates should not count + onExecutorAddedDefaultProfile(manager, "second") + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + + // Try adding again + // This should still fail because the number pending + running is still at the limit + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + } + + test("add executors multiple profiles") { + val manager = createManager(createConf(1, 10, 1)) + post(SparkListenerStageSubmitted(createStageInfo(0, 1000, rp = defaultProfile))) + val rp1 = new ResourceProfileBuilder() + val execReqs = new 
ExecutorResourceRequests().cores(4).resource("gpu", 4) + val taskReqs = new TaskResourceRequests().cpus(1).resource("gpu", 1) + rp1.require(execReqs).require(taskReqs) + val rprof1 = rp1.build + rpManager.addResourceProfile(rprof1) + post(SparkListenerStageSubmitted(createStageInfo(1, 1000, rp = rprof1))) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + // Keep adding until the limit is reached - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 2) - assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 4) - assert(numExecutorsToAdd(manager) === 4) - assert(addExecutors(manager) === 4) - assert(numExecutorsTarget(manager) === 8) - assert(numExecutorsToAdd(manager) === 8) - assert(addExecutors(manager) === 2) // reached the limit of 10 - assert(numExecutorsTarget(manager) === 10) - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 0) - assert(numExecutorsTarget(manager) === 10) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + assert(numExecutorsToAdd(manager, rprof1) === 1) + assert(numExecutorsTarget(manager, rprof1.id) === 1) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + assert(numExecutorsToAdd(manager, rprof1) === 2) + assert(numExecutorsTarget(manager, rprof1.id) === 2) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 4) + assert(numExecutorsToAddForDefaultProfile(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 4) + assert(numExecutorsToAdd(manager, rprof1) === 4) + assert(numExecutorsTarget(manager, rprof1.id) === 4) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 4) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) + assert(numExecutorsToAddForDefaultProfile(manager) === 8) + // reached the limit of 10 + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + assert(numExecutorsToAdd(manager, rprof1) === 8) + assert(numExecutorsTarget(manager, rprof1.id) === 8) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + assert(numExecutorsToAdd(manager, rprof1) === 1) + assert(numExecutorsTarget(manager, rprof1.id) === 10) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + 
assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(numExecutorsToAdd(manager, rprof1) === 1) + assert(numExecutorsTarget(manager, rprof1.id) === 10) // Register previously requested executors - onExecutorAdded(manager, "first") - assert(numExecutorsTarget(manager) === 10) - onExecutorAdded(manager, "second") - onExecutorAdded(manager, "third") - onExecutorAdded(manager, "fourth") - assert(numExecutorsTarget(manager) === 10) - onExecutorAdded(manager, "first") // duplicates should not count - onExecutorAdded(manager, "second") - assert(numExecutorsTarget(manager) === 10) + onExecutorAddedDefaultProfile(manager, "first") + onExecutorAdded(manager, "firstrp1", rprof1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsTarget(manager, rprof1.id) === 10) + onExecutorAddedDefaultProfile(manager, "second") + onExecutorAddedDefaultProfile(manager, "third") + onExecutorAddedDefaultProfile(manager, "fourth") + onExecutorAdded(manager, "secondrp1", rprof1) + onExecutorAdded(manager, "thirdrp1", rprof1) + onExecutorAdded(manager, "fourthrp1", rprof1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsTarget(manager, rprof1.id) === 10) + onExecutorAddedDefaultProfile(manager, "first") // duplicates should not count + onExecutorAddedDefaultProfile(manager, "second") + onExecutorAdded(manager, "firstrp1", rprof1) + onExecutorAdded(manager, "secondrp1", rprof1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsTarget(manager, rprof1.id) === 10) // Try adding again // This should still fail because the number pending + running is still at the limit - assert(addExecutors(manager) === 0) - assert(numExecutorsTarget(manager) === 10) - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 0) - assert(numExecutorsTarget(manager) === 10) - assert(numExecutorsToAdd(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(numExecutorsToAdd(manager, rprof1) === 1) + assert(numExecutorsTarget(manager, rprof1.id) === 10) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + assert(addExecutorsToTarget(manager, updatesNeeded, rprof1) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(numExecutorsToAdd(manager, rprof1) === 1) + assert(numExecutorsTarget(manager, rprof1.id) === 10) + } + + test("remove executors multiple profiles") { + val manager = createManager(createConf(5, 10, 5)) + val rp1 = new ResourceProfileBuilder() + val execReqs = new ExecutorResourceRequests().cores(4).resource("gpu", 4) + val taskReqs = new TaskResourceRequests().cpus(1).resource("gpu", 1) + rp1.require(execReqs).require(taskReqs) + val rprof1 = rp1.build + val rp2 = new ResourceProfileBuilder() + val execReqs2 = new ExecutorResourceRequests().cores(1) + val taskReqs2 = new TaskResourceRequests().cpus(1) + rp2.require(execReqs2).require(taskReqs2) + val rprof2 = rp2.build + rpManager.addResourceProfile(rprof1) + rpManager.addResourceProfile(rprof2) + post(SparkListenerStageSubmitted(createStageInfo(1, 10, rp = 
rprof1))) + post(SparkListenerStageSubmitted(createStageInfo(2, 10, rp = rprof2))) + + (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id, rprof1) } + (11 to 20).map(_.toString).foreach { id => onExecutorAdded(manager, id, rprof2) } + (21 to 30).map(_.toString).foreach { id => onExecutorAdded(manager, id, defaultProfile) } + + // Keep removing until the limit is reached + assert(executorsPendingToRemove(manager).isEmpty) + assert(removeExecutor(manager, "1", rprof1.id)) + assert(executorsPendingToRemove(manager).size === 1) + assert(executorsPendingToRemove(manager).contains("1")) + assert(removeExecutor(manager, "11", rprof2.id)) + assert(removeExecutor(manager, "2", rprof1.id)) + assert(executorsPendingToRemove(manager).size === 3) + assert(executorsPendingToRemove(manager).contains("2")) + assert(executorsPendingToRemove(manager).contains("11")) + assert(removeExecutor(manager, "21", defaultProfile.id)) + assert(removeExecutor(manager, "3", rprof1.id)) + assert(removeExecutor(manager, "4", rprof1.id)) + assert(executorsPendingToRemove(manager).size === 6) + assert(executorsPendingToRemove(manager).contains("21")) + assert(executorsPendingToRemove(manager).contains("3")) + assert(executorsPendingToRemove(manager).contains("4")) + assert(removeExecutor(manager, "5", rprof1.id)) + assert(!removeExecutor(manager, "6", rprof1.id)) // reached the limit of 5 + assert(executorsPendingToRemove(manager).size === 7) + assert(executorsPendingToRemove(manager).contains("5")) + assert(!executorsPendingToRemove(manager).contains("6")) + + // Kill executors previously requested to remove + onExecutorRemoved(manager, "1") + assert(executorsPendingToRemove(manager).size === 6) + assert(!executorsPendingToRemove(manager).contains("1")) + onExecutorRemoved(manager, "2") + onExecutorRemoved(manager, "3") + assert(executorsPendingToRemove(manager).size === 4) + assert(!executorsPendingToRemove(manager).contains("2")) + assert(!executorsPendingToRemove(manager).contains("3")) + onExecutorRemoved(manager, "2") // duplicates should not count + onExecutorRemoved(manager, "3") + assert(executorsPendingToRemove(manager).size === 4) + onExecutorRemoved(manager, "4") + onExecutorRemoved(manager, "5") + assert(executorsPendingToRemove(manager).size === 2) + assert(executorsPendingToRemove(manager).contains("11")) + assert(executorsPendingToRemove(manager).contains("21")) + + // Try removing again + // This should still fail because the number pending + running is still at the limit + assert(!removeExecutor(manager, "7", rprof1.id)) + assert(executorsPendingToRemove(manager).size === 2) + assert(!removeExecutor(manager, "8", rprof1.id)) + assert(executorsPendingToRemove(manager).size === 2) + + // make sure rprof2 has the same min limit or 5 + assert(removeExecutor(manager, "12", rprof2.id)) + assert(removeExecutor(manager, "13", rprof2.id)) + assert(removeExecutor(manager, "14", rprof2.id)) + assert(removeExecutor(manager, "15", rprof2.id)) + assert(!removeExecutor(manager, "16", rprof2.id)) // reached the limit of 5 + assert(executorsPendingToRemove(manager).size === 6) + assert(!executorsPendingToRemove(manager).contains("16")) + onExecutorRemoved(manager, "11") + onExecutorRemoved(manager, "12") + onExecutorRemoved(manager, "13") + onExecutorRemoved(manager, "14") + onExecutorRemoved(manager, "15") + assert(executorsPendingToRemove(manager).size === 1) } def testAllocationRatio(cores: Int, divisor: Double, expected: Int): Unit = { + val updatesNeeded = + new mutable.HashMap[ResourceProfile, 
ExecutorAllocationManager.TargetNumUpdates] val conf = createConf(3, 15) .set(config.DYN_ALLOCATION_EXECUTOR_ALLOCATION_RATIO, divisor) .set(config.EXECUTOR_CORES, cores) val manager = createManager(conf) post(SparkListenerStageSubmitted(createStageInfo(0, 20))) for (i <- 0 to 5) { - addExecutors(manager) + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) } - assert(numExecutorsTarget(manager) === expected) + assert(numExecutorsTargetForDefaultProfileId(manager) === expected) } test("executionAllocationRatio is correctly handled") { @@ -185,127 +381,158 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val manager = createManager(createConf(0, 10, 0)) post(SparkListenerStageSubmitted(createStageInfo(0, 5))) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + // Verify that we're capped at number of tasks in the stage - assert(numExecutorsTarget(manager) === 0) - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 3) - assert(numExecutorsToAdd(manager) === 4) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 5) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) + assert(numExecutorsToAddForDefaultProfile(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) // Verify that running a task doesn't affect the target post(SparkListenerStageSubmitted(createStageInfo(1, 3))) post(SparkListenerExecutorAdded( 0L, "executor-1", new ExecutorInfo("host1", 1, Map.empty, Map.empty))) post(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) - assert(numExecutorsTarget(manager) === 5) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 6) - assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 8) - assert(numExecutorsToAdd(manager) === 4) - assert(addExecutors(manager) === 0) - assert(numExecutorsTarget(manager) === 8) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 6) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + + doUpdateRequest(manager, updatesNeeded.toMap, 
clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) + assert(numExecutorsToAddForDefaultProfile(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) // Verify that re-running a task doesn't blow things up post(SparkListenerStageSubmitted(createStageInfo(2, 3))) post(SparkListenerTaskStart(2, 0, createTaskInfo(0, 0, "executor-1"))) post(SparkListenerTaskStart(2, 0, createTaskInfo(1, 0, "executor-1"))) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 9) - assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 10) - assert(numExecutorsToAdd(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 9) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) // Verify that running a task once we're at our limit doesn't blow things up post(SparkListenerTaskStart(2, 0, createTaskInfo(0, 1, "executor-1"))) - assert(addExecutors(manager) === 0) - assert(numExecutorsTarget(manager) === 10) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 10) } test("add executors when speculative tasks added") { val manager = createManager(createConf(0, 10, 0)) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + + post(SparkListenerStageSubmitted(createStageInfo(1, 2))) // Verify that we're capped at number of tasks including the speculative ones in the stage post(SparkListenerSpeculativeTaskSubmitted(1)) - assert(numExecutorsTarget(manager) === 0) - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) post(SparkListenerSpeculativeTaskSubmitted(1)) post(SparkListenerSpeculativeTaskSubmitted(1)) - post(SparkListenerStageSubmitted(createStageInfo(1, 2))) - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 3) - assert(numExecutorsToAdd(manager) === 4) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 5) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) + 
assert(numExecutorsToAddForDefaultProfile(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) // Verify that running a task doesn't affect the target post(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1"))) - assert(numExecutorsTarget(manager) === 5) - assert(addExecutors(manager) === 0) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) // Verify that running a speculative task doesn't affect the target post(SparkListenerTaskStart(1, 0, createTaskInfo(1, 0, "executor-2", true))) - assert(numExecutorsTarget(manager) === 5) - assert(addExecutors(manager) === 0) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) } test("SPARK-30511 remove executors when speculative tasks end") { val clock = new ManualClock() val stage = createStageInfo(0, 40) - val manager = createManager(createConf(0, 10, 0).set(config.EXECUTOR_CORES, 4), clock = clock) + val conf = createConf(0, 10, 0).set(config.EXECUTOR_CORES, 4) + val manager = createManager(conf, clock = clock) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] post(SparkListenerStageSubmitted(stage)) - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 2) - assert(addExecutors(manager) === 4) - assert(addExecutors(manager) === 3) - - (0 to 9).foreach(execId => onExecutorAdded(manager, execId.toString)) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 4) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 3) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + + (0 to 9).foreach(execId => onExecutorAddedDefaultProfile(manager, execId.toString)) (0 to 39).map { i => createTaskInfo(i, i, executorId = s"${i / 4}")}.foreach { info => post(SparkListenerTaskStart(0, 0, info)) } - assert(numExecutorsTarget(manager) === 10) - assert(maxNumExecutorsNeeded(manager) == 10) + assert(numExecutorsTarget(manager, defaultProfile.id) === 10) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 10) // 30 tasks (0 - 29) finished (0 to 29).map { i => createTaskInfo(i, i, executorId = s"${i / 4}")}.foreach { info => post(SparkListenerTaskEnd(0, 0, null, Success, info, new ExecutorMetrics, null)) } clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 3) - assert(maxNumExecutorsNeeded(manager) == 3) - (0 to 
6).foreach { i => assert(removeExecutor(manager, i.toString))} + assert(numExecutorsTarget(manager, defaultProfile.id) === 3) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 3) + (0 to 6).foreach { i => assert(removeExecutorDefaultProfile(manager, i.toString))} (0 to 6).foreach { i => onExecutorRemoved(manager, i.toString)} // 10 speculative tasks (30 - 39) launch for the remaining tasks (30 to 39).foreach { _ => post(SparkListenerSpeculativeTaskSubmitted(0))} - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) == 5) - assert(maxNumExecutorsNeeded(manager) == 5) - (10 to 12).foreach(execId => onExecutorAdded(manager, execId.toString)) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTarget(manager, defaultProfile.id) == 5) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 5) + (10 to 12).foreach(execId => onExecutorAddedDefaultProfile(manager, execId.toString)) (40 to 49).map { i => createTaskInfo(taskId = i, taskIndex = i - 10, executorId = s"${i / 4}", speculative = true)} .foreach { info => post(SparkListenerTaskStart(0, 0, info))} clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) == 5) // At this point, we still have 6 executors running - assert(maxNumExecutorsNeeded(manager) == 5) + // At this point, we still have 6 executors running + assert(numExecutorsTarget(manager, defaultProfile.id) == 5) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 5) // 6 speculative tasks (40 - 45) finish before the original tasks, with 4 speculative remaining (40 to 45).map { i => @@ -314,9 +541,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { info => post(SparkListenerTaskEnd(0, 0, null, Success, info, new ExecutorMetrics, null))} clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 4) - assert(maxNumExecutorsNeeded(manager) == 4) - assert(removeExecutor(manager, "10")) + assert(numExecutorsTarget(manager, defaultProfile.id) === 4) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 4) + assert(removeExecutorDefaultProfile(manager, "10")) onExecutorRemoved(manager, "10") // At this point, we still have 5 executors running: ["7", "8", "9", "11", "12"] @@ -327,9 +554,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { SparkListenerTaskEnd(0, 0, null, TaskKilled("test"), info, new ExecutorMetrics, null))} clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 2) - assert(maxNumExecutorsNeeded(manager) == 2) - (7 to 8).foreach { i => assert(removeExecutor(manager, i.toString))} + assert(numExecutorsTarget(manager, defaultProfile.id) === 2) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 2) + (7 to 8).foreach { i => assert(removeExecutorDefaultProfile(manager, i.toString))} (7 to 8).foreach { i => onExecutorRemoved(manager, i.toString)} // At this point, we still have 3 executors running: ["9", "11", "12"] @@ -343,8 +570,8 @@ class ExecutorAllocationManagerSuite extends 
SparkFunSuite { // tasks running. Target lowers to 2, but still hold 3 executors ["9", "11", "12"] clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 2) - assert(maxNumExecutorsNeeded(manager) == 2) + assert(numExecutorsTarget(manager, defaultProfile.id) === 2) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 2) // At this point, we still have 3 executors running: ["9", "11", "12"] // Task 37 and 47 succeed at the same time @@ -357,9 +584,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { // tasks running clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 1) - assert(maxNumExecutorsNeeded(manager) == 1) - assert(removeExecutor(manager, "11")) + assert(numExecutorsTarget(manager, defaultProfile.id) === 1) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 1) + assert(removeExecutorDefaultProfile(manager, "11")) onExecutorRemoved(manager, "11") // At this point, we still have 2 executors running: ["9", "12"] @@ -372,14 +599,14 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) // maxNeeded = 1, allocate one more to satisfy speculation locality requirement - assert(numExecutorsTarget(manager) === 2) - assert(maxNumExecutorsNeeded(manager) == 2) + assert(numExecutorsTarget(manager, defaultProfile.id) === 2) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 2) post(SparkListenerTaskStart(0, 0, createTaskInfo(50, 39, executorId = "12", speculative = true))) clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 1) - assert(maxNumExecutorsNeeded(manager) == 1) + assert(numExecutorsTarget(manager, defaultProfile.id) === 1) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 1) // Task 39 and 48 succeed, task 50 killed post(SparkListenerTaskEnd(0, 0, null, Success, @@ -391,11 +618,11 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { post(SparkListenerStageCompleted(stage)) clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 0) - assert(maxNumExecutorsNeeded(manager) == 0) - assert(removeExecutor(manager, "9")) + assert(numExecutorsTarget(manager, defaultProfile.id) === 0) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 0) + assert(removeExecutorDefaultProfile(manager, "9")) onExecutorRemoved(manager, "9") - assert(removeExecutor(manager, "12")) + assert(removeExecutorDefaultProfile(manager, "12")) onExecutorRemoved(manager, "12") } @@ -417,43 +644,49 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { post(SparkListenerStageCompleted(stage)) // There are still two tasks that belong to the zombie stage running. - assert(totalRunningTasks(manager) === 2) + assert(totalRunningTasksPerResourceProfile(manager) === 2) // submit another attempt for the stage. 
We count completions from the first zombie attempt val stageAttempt1 = createStageInfo(stage.stageId, 5, attemptId = 1) post(SparkListenerStageSubmitted(stageAttempt1)) post(SparkListenerTaskEnd(0, 0, null, Success, taskInfo1, new ExecutorMetrics, null)) - assert(totalRunningTasks(manager) === 1) + assert(totalRunningTasksPerResourceProfile(manager) === 1) val attemptTaskInfo1 = createTaskInfo(3, 0, "executor-1") val attemptTaskInfo2 = createTaskInfo(4, 1, "executor-1") post(SparkListenerTaskStart(0, 1, attemptTaskInfo1)) post(SparkListenerTaskStart(0, 1, attemptTaskInfo2)) - assert(totalRunningTasks(manager) === 3) + assert(totalRunningTasksPerResourceProfile(manager) === 3) post(SparkListenerTaskEnd(0, 1, null, Success, attemptTaskInfo1, new ExecutorMetrics, null)) - assert(totalRunningTasks(manager) === 2) + assert(totalRunningTasksPerResourceProfile(manager) === 2) post(SparkListenerTaskEnd(0, 0, null, Success, taskInfo2, new ExecutorMetrics, null)) - assert(totalRunningTasks(manager) === 1) + assert(totalRunningTasksPerResourceProfile(manager) === 1) post(SparkListenerTaskEnd(0, 1, null, Success, attemptTaskInfo2, new ExecutorMetrics, null)) - assert(totalRunningTasks(manager) === 0) + assert(totalRunningTasksPerResourceProfile(manager) === 0) } testRetry("cancel pending executors when no longer needed") { val manager = createManager(createConf(0, 10, 0)) post(SparkListenerStageSubmitted(createStageInfo(2, 5))) - assert(numExecutorsTarget(manager) === 0) - assert(numExecutorsToAdd(manager) === 1) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 2) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 3) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) val task1Info = createTaskInfo(0, 0, "executor-1") post(SparkListenerTaskStart(2, 0, task1Info)) - assert(numExecutorsToAdd(manager) === 4) - assert(addExecutors(manager) === 2) + assert(numExecutorsToAddForDefaultProfile(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) val task2Info = createTaskInfo(1, 0, "executor-1") post(SparkListenerTaskStart(2, 0, task2Info)) @@ -469,22 +702,21 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { test("remove executors") { val manager = createManager(createConf(5, 10, 5)) - (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } + (1 to 10).map(_.toString).foreach { id => onExecutorAddedDefaultProfile(manager, id) } // Keep removing until the limit is reached assert(executorsPendingToRemove(manager).isEmpty) - assert(removeExecutor(manager, "1")) + assert(removeExecutorDefaultProfile(manager, "1")) assert(executorsPendingToRemove(manager).size === 1) assert(executorsPendingToRemove(manager).contains("1")) - 
assert(removeExecutor(manager, "2")) - assert(removeExecutor(manager, "3")) + assert(removeExecutorDefaultProfile(manager, "2")) + assert(removeExecutorDefaultProfile(manager, "3")) assert(executorsPendingToRemove(manager).size === 3) assert(executorsPendingToRemove(manager).contains("2")) assert(executorsPendingToRemove(manager).contains("3")) - assert(executorsPendingToRemove(manager).size === 3) - assert(removeExecutor(manager, "4")) - assert(removeExecutor(manager, "5")) - assert(!removeExecutor(manager, "6")) // reached the limit of 5 + assert(removeExecutorDefaultProfile(manager, "4")) + assert(removeExecutorDefaultProfile(manager, "5")) + assert(!removeExecutorDefaultProfile(manager, "6")) // reached the limit of 5 assert(executorsPendingToRemove(manager).size === 5) assert(executorsPendingToRemove(manager).contains("4")) assert(executorsPendingToRemove(manager).contains("5")) @@ -508,29 +740,29 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { // Try removing again // This should still fail because the number pending + running is still at the limit - assert(!removeExecutor(manager, "7")) + assert(!removeExecutorDefaultProfile(manager, "7")) assert(executorsPendingToRemove(manager).isEmpty) - assert(!removeExecutor(manager, "8")) + assert(!removeExecutorDefaultProfile(manager, "8")) assert(executorsPendingToRemove(manager).isEmpty) } test("remove multiple executors") { val manager = createManager(createConf(5, 10, 5)) - (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) } + (1 to 10).map(_.toString).foreach { id => onExecutorAddedDefaultProfile(manager, id) } // Keep removing until the limit is reached assert(executorsPendingToRemove(manager).isEmpty) - assert(removeExecutors(manager, Seq("1")) === Seq("1")) + assert(removeExecutorsDefaultProfile(manager, Seq("1")) === Seq("1")) assert(executorsPendingToRemove(manager).size === 1) assert(executorsPendingToRemove(manager).contains("1")) - assert(removeExecutors(manager, Seq("2", "3")) === Seq("2", "3")) + assert(removeExecutorsDefaultProfile(manager, Seq("2", "3")) === Seq("2", "3")) assert(executorsPendingToRemove(manager).size === 3) assert(executorsPendingToRemove(manager).contains("2")) assert(executorsPendingToRemove(manager).contains("3")) assert(executorsPendingToRemove(manager).size === 3) - assert(removeExecutor(manager, "4")) - assert(removeExecutors(manager, Seq("5")) === Seq("5")) - assert(!removeExecutor(manager, "6")) // reached the limit of 5 + assert(removeExecutorDefaultProfile(manager, "4")) + assert(removeExecutorsDefaultProfile(manager, Seq("5")) === Seq("5")) + assert(!removeExecutorDefaultProfile(manager, "6")) // reached the limit of 5 assert(executorsPendingToRemove(manager).size === 5) assert(executorsPendingToRemove(manager).contains("4")) assert(executorsPendingToRemove(manager).contains("5")) @@ -554,87 +786,100 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { // Try removing again // This should still fail because the number pending + running is still at the limit - assert(!removeExecutor(manager, "7")) + assert(!removeExecutorDefaultProfile(manager, "7")) assert(executorsPendingToRemove(manager).isEmpty) - assert(removeExecutors(manager, Seq("8")) !== Seq("8")) + assert(removeExecutorsDefaultProfile(manager, Seq("8")) !== Seq("8")) assert(executorsPendingToRemove(manager).isEmpty) } - test ("Removing with various numExecutorsTarget condition") { + test ("Removing with various numExecutorsTargetForDefaultProfileId condition") { val manager = 
createManager(createConf(5, 12, 5)) post(SparkListenerStageSubmitted(createStageInfo(0, 8))) - // Remove when numExecutorsTarget is the same as the current number of executors - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 2) - (1 to 8).foreach(execId => onExecutorAdded(manager, execId.toString)) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + + // Remove when numExecutorsTargetForDefaultProfileId is the same as the current + // number of executors + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + (1 to 8).foreach(execId => onExecutorAddedDefaultProfile(manager, execId.toString)) (1 to 8).map { i => createTaskInfo(i, i, s"$i") }.foreach { info => post(SparkListenerTaskStart(0, 0, info)) } assert(manager.executorMonitor.executorCount === 8) - assert(numExecutorsTarget(manager) === 8) - assert(maxNumExecutorsNeeded(manager) == 8) - assert(!removeExecutor(manager, "1")) // won't work since numExecutorsTarget == numExecutors + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 8) + // won't work since numExecutorsTargetForDefaultProfileId == numExecutors + assert(!removeExecutorDefaultProfile(manager, "1")) - // Remove executors when numExecutorsTarget is lower than current number of executors + // Remove executors when numExecutorsTargetForDefaultProfileId is lower than + // current number of executors (1 to 3).map { i => createTaskInfo(i, i, s"$i") }.foreach { info => post(SparkListenerTaskEnd(0, 0, null, Success, info, new ExecutorMetrics, null)) } adjustRequestedExecutors(manager) assert(manager.executorMonitor.executorCount === 8) - assert(numExecutorsTarget(manager) === 5) - assert(maxNumExecutorsNeeded(manager) == 5) - assert(removeExecutor(manager, "1")) - assert(removeExecutors(manager, Seq("2", "3"))=== Seq("2", "3")) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 5) + assert(removeExecutorDefaultProfile(manager, "1")) + assert(removeExecutorsDefaultProfile(manager, Seq("2", "3"))=== Seq("2", "3")) onExecutorRemoved(manager, "1") onExecutorRemoved(manager, "2") onExecutorRemoved(manager, "3") - // numExecutorsTarget is lower than minNumExecutors + // numExecutorsTargetForDefaultProfileId is lower than minNumExecutors post(SparkListenerTaskEnd(0, 0, null, Success, createTaskInfo(4, 4, "4"), new ExecutorMetrics, null)) assert(manager.executorMonitor.executorCount === 5) - assert(numExecutorsTarget(manager) === 5) - assert(maxNumExecutorsNeeded(manager) == 4) - assert(!removeExecutor(manager, "4")) // lower limit - assert(addExecutors(manager) === 0) // upper limit + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) == 4) + assert(!removeExecutorDefaultProfile(manager, "4")) // lower limit + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) // upper limit } test ("interleaving add and remove") { val manager = createManager(createConf(5, 12, 5)) post(SparkListenerStageSubmitted(createStageInfo(0, 1000))) + val updatesNeeded = + new mutable.HashMap[ResourceProfile, 
ExecutorAllocationManager.TargetNumUpdates] + // Add a few executors - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 2) - onExecutorAdded(manager, "1") - onExecutorAdded(manager, "2") - onExecutorAdded(manager, "3") - onExecutorAdded(manager, "4") - onExecutorAdded(manager, "5") - onExecutorAdded(manager, "6") - onExecutorAdded(manager, "7") - onExecutorAdded(manager, "8") + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + onExecutorAddedDefaultProfile(manager, "1") + onExecutorAddedDefaultProfile(manager, "2") + onExecutorAddedDefaultProfile(manager, "3") + onExecutorAddedDefaultProfile(manager, "4") + onExecutorAddedDefaultProfile(manager, "5") + onExecutorAddedDefaultProfile(manager, "6") + onExecutorAddedDefaultProfile(manager, "7") + onExecutorAddedDefaultProfile(manager, "8") assert(manager.executorMonitor.executorCount === 8) - assert(numExecutorsTarget(manager) === 8) + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) // Remove when numTargetExecutors is equal to the current number of executors - assert(!removeExecutor(manager, "1")) - assert(removeExecutors(manager, Seq("2", "3")) !== Seq("2", "3")) + assert(!removeExecutorDefaultProfile(manager, "1")) + assert(removeExecutorsDefaultProfile(manager, Seq("2", "3")) !== Seq("2", "3")) // Remove until limit - onExecutorAdded(manager, "9") - onExecutorAdded(manager, "10") - onExecutorAdded(manager, "11") - onExecutorAdded(manager, "12") + onExecutorAddedDefaultProfile(manager, "9") + onExecutorAddedDefaultProfile(manager, "10") + onExecutorAddedDefaultProfile(manager, "11") + onExecutorAddedDefaultProfile(manager, "12") assert(manager.executorMonitor.executorCount === 12) - assert(numExecutorsTarget(manager) === 8) + assert(numExecutorsTargetForDefaultProfileId(manager) === 8) - assert(removeExecutor(manager, "1")) - assert(removeExecutors(manager, Seq("2", "3", "4")) === Seq("2", "3", "4")) - assert(!removeExecutor(manager, "5")) // lower limit reached - assert(!removeExecutor(manager, "6")) + assert(removeExecutorDefaultProfile(manager, "1")) + assert(removeExecutorsDefaultProfile(manager, Seq("2", "3", "4")) === Seq("2", "3", "4")) + assert(!removeExecutorDefaultProfile(manager, "5")) // lower limit reached + assert(!removeExecutorDefaultProfile(manager, "6")) onExecutorRemoved(manager, "1") onExecutorRemoved(manager, "2") onExecutorRemoved(manager, "3") @@ -642,33 +887,36 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { assert(manager.executorMonitor.executorCount === 8) // Add until limit - assert(!removeExecutor(manager, "7")) // still at lower limit + assert(!removeExecutorDefaultProfile(manager, "7")) // still at lower limit assert((manager, Seq("8")) !== Seq("8")) - onExecutorAdded(manager, "13") - onExecutorAdded(manager, "14") - onExecutorAdded(manager, "15") - onExecutorAdded(manager, "16") + onExecutorAddedDefaultProfile(manager, "13") + onExecutorAddedDefaultProfile(manager, "14") + onExecutorAddedDefaultProfile(manager, "15") + onExecutorAddedDefaultProfile(manager, "16") assert(manager.executorMonitor.executorCount === 12) // Remove succeeds again, now that we are no longer at the lower limit - assert(removeExecutors(manager, Seq("5", "6", "7")) === Seq("5", "6", "7")) - assert(removeExecutor(manager, "8")) + 
assert(removeExecutorsDefaultProfile(manager, Seq("5", "6", "7")) === Seq("5", "6", "7")) + assert(removeExecutorDefaultProfile(manager, "8")) assert(manager.executorMonitor.executorCount === 12) onExecutorRemoved(manager, "5") onExecutorRemoved(manager, "6") assert(manager.executorMonitor.executorCount === 10) - assert(numExecutorsToAdd(manager) === 4) + assert(numExecutorsToAddForDefaultProfile(manager) === 4) onExecutorRemoved(manager, "9") onExecutorRemoved(manager, "10") - assert(addExecutors(manager) === 4) // at upper limit - onExecutorAdded(manager, "17") - onExecutorAdded(manager, "18") + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 4) // at upper limit + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + onExecutorAddedDefaultProfile(manager, "17") + onExecutorAddedDefaultProfile(manager, "18") assert(manager.executorMonitor.executorCount === 10) - assert(addExecutors(manager) === 0) // still at upper limit - onExecutorAdded(manager, "19") - onExecutorAdded(manager, "20") + // still at upper limit + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 0) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + onExecutorAddedDefaultProfile(manager, "19") + onExecutorAddedDefaultProfile(manager, "20") assert(manager.executorMonitor.executorCount === 12) - assert(numExecutorsTarget(manager) === 12) + assert(numExecutorsTargetForDefaultProfileId(manager) === 12) } test("starting/canceling add timer") { @@ -706,22 +954,22 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val manager = createManager(createConf(0, 20, 0), clock = clock) // No events - we should not be adding or removing - assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) schedule(manager) - assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(100L) schedule(manager) - assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(1000L) schedule(manager) - assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) clock.advance(10000L) schedule(manager) - assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) assert(executorsPendingToRemove(manager).isEmpty) } @@ -734,43 +982,43 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { onSchedulerBacklogged(manager) clock.advance(schedulerBacklogTimeout * 1000 / 2) schedule(manager) - assert(numExecutorsTarget(manager) === 0) // timer not exceeded yet + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) // timer not exceeded yet clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 1) // first timer exceeded + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) // first timer exceeded clock.advance(sustainedSchedulerBacklogTimeout * 1000 / 2) schedule(manager) - assert(numExecutorsTarget(manager) === 1) // second timer not exceeded yet + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) // second timer not exceeded yet clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - 
assert(numExecutorsTarget(manager) === 1 + 2) // second timer exceeded + assert(numExecutorsTargetForDefaultProfileId(manager) === 1 + 2) // second timer exceeded clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 1 + 2 + 4) // third timer exceeded + assert(numExecutorsTargetForDefaultProfileId(manager) === 1 + 2 + 4) // third timer exceeded // Scheduler queue drained onSchedulerQueueEmpty(manager) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 7) // timer is canceled + assert(numExecutorsTargetForDefaultProfileId(manager) === 7) // timer is canceled clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 7) + assert(numExecutorsTargetForDefaultProfileId(manager) === 7) // Scheduler queue backlogged again onSchedulerBacklogged(manager) clock.advance(schedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 7 + 1) // timer restarted + assert(numExecutorsTargetForDefaultProfileId(manager) === 7 + 1) // timer restarted clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 7 + 1 + 2) + assert(numExecutorsTargetForDefaultProfileId(manager) === 7 + 1 + 2) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 7 + 1 + 2 + 4) + assert(numExecutorsTargetForDefaultProfileId(manager) === 7 + 1 + 2 + 4) clock.advance(sustainedSchedulerBacklogTimeout * 1000) schedule(manager) - assert(numExecutorsTarget(manager) === 20) // limit reached + assert(numExecutorsTargetForDefaultProfileId(manager) === 20) // limit reached } test("mock polling loop remove behavior") { @@ -778,9 +1026,9 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val manager = createManager(createConf(1, 20, 1), clock = clock) // Remove idle executors on timeout - onExecutorAdded(manager, "executor-1") - onExecutorAdded(manager, "executor-2") - onExecutorAdded(manager, "executor-3") + onExecutorAddedDefaultProfile(manager, "executor-1") + onExecutorAddedDefaultProfile(manager, "executor-2") + onExecutorAddedDefaultProfile(manager, "executor-3") assert(executorsPendingToRemove(manager).isEmpty) // idle threshold not reached yet @@ -796,10 +1044,10 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { assert(executorsPendingToRemove(manager).size === 2) // limit reached (1 executor remaining) // Mark a subset as busy - only idle executors should be removed - onExecutorAdded(manager, "executor-4") - onExecutorAdded(manager, "executor-5") - onExecutorAdded(manager, "executor-6") - onExecutorAdded(manager, "executor-7") + onExecutorAddedDefaultProfile(manager, "executor-4") + onExecutorAddedDefaultProfile(manager, "executor-5") + onExecutorAddedDefaultProfile(manager, "executor-6") + onExecutorAddedDefaultProfile(manager, "executor-7") assert(manager.executorMonitor.executorCount === 7) assert(executorsPendingToRemove(manager).size === 2) // 2 pending to be removed onExecutorBusy(manager, "executor-4") @@ -864,23 +1112,31 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val stage1 = createStageInfo(0, 1000) post(SparkListenerStageSubmitted(stage1)) - assert(addExecutors(manager) === 1) - assert(addExecutors(manager) === 2) - assert(addExecutors(manager) === 4) - assert(addExecutors(manager) === 8) - assert(numExecutorsTarget(manager) === 15) + val updatesNeeded = + new 
mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] + + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 4) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 8) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 15) (0 until 15).foreach { i => - onExecutorAdded(manager, s"executor-$i") + onExecutorAddedDefaultProfile(manager, s"executor-$i") } assert(manager.executorMonitor.executorCount === 15) post(SparkListenerStageCompleted(stage1)) adjustRequestedExecutors(manager) - assert(numExecutorsTarget(manager) === 0) + assert(numExecutorsTargetForDefaultProfileId(manager) === 0) post(SparkListenerStageSubmitted(createStageInfo(1, 1000))) - addExecutors(manager) - assert(numExecutorsTarget(manager) === 16) + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 16) } test("avoid ramp down initial executors until first job is submitted") { @@ -888,19 +1144,19 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val manager = createManager(createConf(2, 5, 3), clock = clock) // Verify the initial number of executors - assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) schedule(manager) // Verify whether the initial number of executors is kept with no pending tasks - assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) post(SparkListenerStageSubmitted(createStageInfo(1, 2))) clock.advance(100L) - assert(maxNumExecutorsNeeded(manager) === 2) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 2) schedule(manager) // Verify that current number of executors should be ramp down when first job is submitted - assert(numExecutorsTarget(manager) === 2) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) } test("avoid ramp down initial executors until idle executor is timeout") { @@ -908,20 +1164,20 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val manager = createManager(createConf(2, 5, 3), clock = clock) // Verify the initial number of executors - assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) schedule(manager) // Verify the initial number of executors is kept when no pending tasks - assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) (0 until 3).foreach { i => - onExecutorAdded(manager, s"executor-$i") + onExecutorAddedDefaultProfile(manager, s"executor-$i") } clock.advance(executorIdleTimeout * 1000) - assert(maxNumExecutorsNeeded(manager) === 0) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 0) schedule(manager) - // Verify executor is timeout,numExecutorsTarget is recalculated - assert(numExecutorsTarget(manager) === 2) + // Verify executor is timeout,numExecutorsTargetForDefaultProfileId is recalculated + 
assert(numExecutorsTargetForDefaultProfileId(manager) === 2) } test("get pending task number and related locality preference") { @@ -937,7 +1193,8 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val stageInfo1 = createStageInfo(1, 5, localityPreferences1) post(SparkListenerStageSubmitted(stageInfo1)) - assert(localityAwareTasks(manager) === 3) + assert(localityAwareTasksForDefaultProfile(manager) === 3) + val hostToLocal = hostToLocalTaskCount(manager) assert(hostToLocalTaskCount(manager) === Map("host1" -> 2, "host2" -> 3, "host3" -> 2, "host4" -> 2)) @@ -949,67 +1206,76 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { val stageInfo2 = createStageInfo(2, 3, localityPreferences2) post(SparkListenerStageSubmitted(stageInfo2)) - assert(localityAwareTasks(manager) === 5) + assert(localityAwareTasksForDefaultProfile(manager) === 5) assert(hostToLocalTaskCount(manager) === Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2)) post(SparkListenerStageCompleted(stageInfo1)) - assert(localityAwareTasks(manager) === 2) + assert(localityAwareTasksForDefaultProfile(manager) === 2) assert(hostToLocalTaskCount(manager) === Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2)) } - test("SPARK-8366: maxNumExecutorsNeeded should properly handle failed tasks") { + test("SPARK-8366: maxNumExecutorsNeededPerResourceProfile should properly handle failed tasks") { val manager = createManager(createConf()) - assert(maxNumExecutorsNeeded(manager) === 0) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 0) post(SparkListenerStageSubmitted(createStageInfo(0, 1))) - assert(maxNumExecutorsNeeded(manager) === 1) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 1) val taskInfo = createTaskInfo(1, 1, "executor-1") post(SparkListenerTaskStart(0, 0, taskInfo)) - assert(maxNumExecutorsNeeded(manager) === 1) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 1) // If the task is failed, we expect it to be resubmitted later. val taskEndReason = ExceptionFailure(null, null, null, null, None) post(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, new ExecutorMetrics, null)) - assert(maxNumExecutorsNeeded(manager) === 1) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 1) } test("reset the state of allocation manager") { val manager = createManager(createConf()) - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) + + val updatesNeeded = + new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates] // Allocation manager is reset when adding executor requests are sent without reporting back // executor added. 
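The pattern exercised throughout these refactored tests splits the old single addExecutors(manager) call into two steps: addExecutorsToTargetForDefaultProfile records the per-ResourceProfile delta into an updatesNeeded map, and doUpdateRequest then pushes those deltas to the cluster manager. A minimal sketch of that sequence, assuming it runs inside a test body of this suite where the private helpers and the listener-bus post() are in scope (the numbers are illustrative, not asserted by the patch itself):

    val clock = new ManualClock()
    val manager = createManager(createConf(0, 10, 0), clock = clock)
    val updatesNeeded =
      new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates]

    // A backlogged stage creates demand for executors under the default profile.
    post(SparkListenerStageSubmitted(createStageInfo(0, 10)))

    // Step 1: compute how many executors the default profile wants and record
    // the per-profile delta in updatesNeeded.
    val delta = addExecutorsToTargetForDefaultProfile(manager, updatesNeeded)
    // Step 2: send the accumulated per-profile updates to the cluster manager.
    doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis())
    // After the first ramp-up step the target has grown by exactly delta.
    assert(numExecutorsTargetForDefaultProfileId(manager) === delta)
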
post(SparkListenerStageSubmitted(createStageInfo(0, 10))) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 2) - assert(addExecutors(manager) === 2) - assert(numExecutorsTarget(manager) === 4) - assert(addExecutors(manager) === 1) - assert(numExecutorsTarget(manager) === 5) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 2) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 4) + assert(addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) === 1) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) manager.reset() - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) assert(manager.executorMonitor.executorCount === 0) // Allocation manager is reset when executors are added. post(SparkListenerStageSubmitted(createStageInfo(0, 10))) - addExecutors(manager) - addExecutors(manager) - addExecutors(manager) - assert(numExecutorsTarget(manager) === 5) - - onExecutorAdded(manager, "first") - onExecutorAdded(manager, "second") - onExecutorAdded(manager, "third") - onExecutorAdded(manager, "fourth") - onExecutorAdded(manager, "fifth") + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + + onExecutorAddedDefaultProfile(manager, "first") + onExecutorAddedDefaultProfile(manager, "second") + onExecutorAddedDefaultProfile(manager, "third") + onExecutorAddedDefaultProfile(manager, "fourth") + onExecutorAddedDefaultProfile(manager, "fifth") assert(manager.executorMonitor.executorCount === 5) // Cluster manager lost will make all the live executors lost, so here simulate this behavior @@ -1020,28 +1286,31 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { onExecutorRemoved(manager, "fifth") manager.reset() - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) assert(manager.executorMonitor.executorCount === 0) // Allocation manager is reset when executors are pending to remove - addExecutors(manager) - addExecutors(manager) - addExecutors(manager) - assert(numExecutorsTarget(manager) === 5) - - onExecutorAdded(manager, "first") - onExecutorAdded(manager, "second") - onExecutorAdded(manager, "third") - onExecutorAdded(manager, "fourth") - onExecutorAdded(manager, "fifth") - onExecutorAdded(manager, "sixth") - onExecutorAdded(manager, "seventh") - onExecutorAdded(manager, "eighth") + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + 
addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + addExecutorsToTargetForDefaultProfile(manager, updatesNeeded) + doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis()) + assert(numExecutorsTargetForDefaultProfileId(manager) === 5) + + onExecutorAddedDefaultProfile(manager, "first") + onExecutorAddedDefaultProfile(manager, "second") + onExecutorAddedDefaultProfile(manager, "third") + onExecutorAddedDefaultProfile(manager, "fourth") + onExecutorAddedDefaultProfile(manager, "fifth") + onExecutorAddedDefaultProfile(manager, "sixth") + onExecutorAddedDefaultProfile(manager, "seventh") + onExecutorAddedDefaultProfile(manager, "eighth") assert(manager.executorMonitor.executorCount === 8) - removeExecutor(manager, "first") - removeExecutors(manager, Seq("second", "third")) + removeExecutorDefaultProfile(manager, "first") + removeExecutorsDefaultProfile(manager, Seq("second", "third")) assert(executorsPendingToRemove(manager) === Set("first", "second", "third")) assert(manager.executorMonitor.executorCount === 8) @@ -1055,8 +1324,8 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { manager.reset() - assert(numExecutorsTarget(manager) === 1) - assert(numExecutorsToAdd(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) + assert(numExecutorsToAddForDefaultProfile(manager) === 1) assert(executorsPendingToRemove(manager) === Set.empty) assert(manager.executorMonitor.executorCount === 0) } @@ -1067,31 +1336,31 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { createConf(1, 2, 1).set(config.DYN_ALLOCATION_TESTING, false), clock = clock) - when(client.requestTotalExecutors(meq(2), any(), any())).thenReturn(true) + when(client.requestTotalExecutors(any(), any(), any())).thenReturn(true) // test setup -- job with 2 tasks, scale up to two executors - assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) post(SparkListenerExecutorAdded( clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty, Map.empty))) post(SparkListenerStageSubmitted(createStageInfo(0, 2))) clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 2) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) val taskInfo0 = createTaskInfo(0, 0, "executor-1") post(SparkListenerTaskStart(0, 0, taskInfo0)) post(SparkListenerExecutorAdded( clock.getTimeMillis(), "executor-2", new ExecutorInfo("host1", 1, Map.empty, Map.empty))) val taskInfo1 = createTaskInfo(1, 1, "executor-2") post(SparkListenerTaskStart(0, 0, taskInfo1)) - assert(numExecutorsTarget(manager) === 2) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) // have one task finish -- we should adjust the target number of executors down // but we should *not* kill any executors yet post(SparkListenerTaskEnd(0, 0, null, Success, taskInfo0, new ExecutorMetrics, null)) - assert(maxNumExecutorsNeeded(manager) === 1) - assert(numExecutorsTarget(manager) === 2) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 2) clock.advance(1000) manager invokePrivate _updateAndSyncNumExecutorsTarget(clock.nanoTime()) - assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) verify(client, never).killExecutors(any(), any(), any(), any()) // now we cross 
the idle timeout for executor-1, so we kill it. the really important @@ -1101,8 +1370,8 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { .thenReturn(Seq("executor-1")) clock.advance(3000) schedule(manager) - assert(maxNumExecutorsNeeded(manager) === 1) - assert(numExecutorsTarget(manager) === 1) + assert(maxNumExecutorsNeededPerResourceProfile(manager, defaultProfile) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) // here's the important verify -- we did kill the executors, but did not adjust the target count verify(client).killExecutors(Seq("executor-1"), false, false, false) } @@ -1110,7 +1379,7 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { test("SPARK-26758 check executor target number after idle time out ") { val clock = new ManualClock(10000L) val manager = createManager(createConf(1, 5, 3), clock = clock) - assert(numExecutorsTarget(manager) === 3) + assert(numExecutorsTargetForDefaultProfileId(manager) === 3) post(SparkListenerExecutorAdded( clock.getTimeMillis(), "executor-1", new ExecutorInfo("host1", 1, Map.empty))) post(SparkListenerExecutorAdded( @@ -1121,14 +1390,14 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { clock.advance(executorIdleTimeout * 1000) schedule(manager) // once the schedule is run target executor number should be 1 - assert(numExecutorsTarget(manager) === 1) + assert(numExecutorsTargetForDefaultProfileId(manager) === 1) } private def createConf( minExecutors: Int = 1, maxExecutors: Int = 5, initialExecutors: Int = 1): SparkConf = { - new SparkConf() + val sparkConf = new SparkConf() .set(config.DYN_ALLOCATION_ENABLED, true) .set(config.DYN_ALLOCATION_MIN_EXECUTORS, minExecutors) .set(config.DYN_ALLOCATION_MAX_EXECUTORS, maxExecutors) @@ -1143,12 +1412,16 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { // SPARK-22864: effectively disable the allocation schedule by setting the period to a // really long value. 
.set(TEST_SCHEDULE_INTERVAL, 10000L) + sparkConf } private def createManager( conf: SparkConf, clock: Clock = new SystemClock()): ExecutorAllocationManager = { - val manager = new ExecutorAllocationManager(client, listenerBus, conf, clock = clock) + ResourceProfile.reInitDefaultProfile(conf) + rpManager = new ResourceProfileManager(conf) + val manager = new ExecutorAllocationManager(client, listenerBus, conf, clock = clock, + resourceProfileManager = rpManager) managers += manager manager.start() manager @@ -1157,7 +1430,18 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { private val execInfo = new ExecutorInfo("host1", 1, Map.empty, Map.empty, Map.empty, DEFAULT_RESOURCE_PROFILE_ID) - private def onExecutorAdded(manager: ExecutorAllocationManager, id: String): Unit = { + private def onExecutorAddedDefaultProfile( + manager: ExecutorAllocationManager, + id: String): Unit = { + post(SparkListenerExecutorAdded(0L, id, execInfo)) + } + + private def onExecutorAdded( + manager: ExecutorAllocationManager, + id: String, + rp: ResourceProfile): Unit = { + val cores = rp.getExecutorCores.getOrElse(1) + val execInfo = new ExecutorInfo("host1", cores, Map.empty, Map.empty, Map.empty, rp.id) post(SparkListenerExecutorAdded(0L, id, execInfo)) } @@ -1176,8 +1460,18 @@ class ExecutorAllocationManagerSuite extends SparkFunSuite { post(SparkListenerTaskEnd(1, 1, "foo", Success, info, new ExecutorMetrics, null)) } - private def removeExecutor(manager: ExecutorAllocationManager, executorId: String): Boolean = { - val executorsRemoved = removeExecutors(manager, Seq(executorId)) + private def removeExecutorDefaultProfile( + manager: ExecutorAllocationManager, + executorId: String): Boolean = { + val executorsRemoved = removeExecutorsDefaultProfile(manager, Seq(executorId)) + executorsRemoved.nonEmpty && executorsRemoved(0) == executorId + } + + private def removeExecutor( + manager: ExecutorAllocationManager, + executorId: String, + rpId: Int): Boolean = { + val executorsRemoved = removeExecutors(manager, Seq((executorId, rpId))) executorsRemoved.nonEmpty && executorsRemoved(0) == executorId } @@ -1199,10 +1493,11 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { stageId: Int, numTasks: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty, - attemptId: Int = 0 + attemptId: Int = 0, + rp: ResourceProfile = defaultProfile ): StageInfo = { new StageInfo(stageId, attemptId, "name", numTasks, Seq.empty, Seq.empty, "no details", - taskLocalityPreferences = taskLocalityPreferences) + taskLocalityPreferences = taskLocalityPreferences, resourceProfileId = rp.id) } private def createTaskInfo( @@ -1217,54 +1512,117 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { | Helper methods for accessing private methods and fields | * ------------------------------------------------------- */ - private val _numExecutorsToAdd = PrivateMethod[Int](Symbol("numExecutorsToAdd")) - private val _numExecutorsTarget = PrivateMethod[Int](Symbol("numExecutorsTarget")) - private val _maxNumExecutorsNeeded = PrivateMethod[Int](Symbol("maxNumExecutorsNeeded")) + private val _numExecutorsToAddPerResourceProfileId = + PrivateMethod[mutable.HashMap[Int, Int]]( + Symbol("numExecutorsToAddPerResourceProfileId")) + private val _numExecutorsTargetPerResourceProfileId = + PrivateMethod[mutable.HashMap[Int, Int]]( + Symbol("numExecutorsTargetPerResourceProfileId")) + private val _maxNumExecutorsNeededPerResourceProfile = + 
PrivateMethod[Int](Symbol("maxNumExecutorsNeededPerResourceProfile")) private val _addTime = PrivateMethod[Long](Symbol("addTime")) private val _schedule = PrivateMethod[Unit](Symbol("schedule")) - private val _addExecutors = PrivateMethod[Int](Symbol("addExecutors")) + private val _doUpdateRequest = PrivateMethod[Unit](Symbol("doUpdateRequest")) private val _updateAndSyncNumExecutorsTarget = PrivateMethod[Int](Symbol("updateAndSyncNumExecutorsTarget")) + private val _addExecutorsToTarget = PrivateMethod[Int](Symbol("addExecutorsToTarget")) private val _removeExecutors = PrivateMethod[Seq[String]](Symbol("removeExecutors")) private val _onSchedulerBacklogged = PrivateMethod[Unit](Symbol("onSchedulerBacklogged")) private val _onSchedulerQueueEmpty = PrivateMethod[Unit](Symbol("onSchedulerQueueEmpty")) - private val _localityAwareTasks = PrivateMethod[Int](Symbol("localityAwareTasks")) - private val _hostToLocalTaskCount = - PrivateMethod[Map[String, Int]](Symbol("hostToLocalTaskCount")) + private val _localityAwareTasksPerResourceProfileId = + PrivateMethod[mutable.HashMap[Int, Int]](Symbol("numLocalityAwareTasksPerResourceProfileId")) + private val _rpIdToHostToLocalTaskCount = + PrivateMethod[Map[Int, Map[String, Int]]](Symbol("rpIdToHostToLocalTaskCount")) private val _onSpeculativeTaskSubmitted = PrivateMethod[Unit](Symbol("onSpeculativeTaskSubmitted")) - private val _totalRunningTasks = PrivateMethod[Int](Symbol("totalRunningTasks")) + private val _totalRunningTasksPerResourceProfile = + PrivateMethod[Int](Symbol("totalRunningTasksPerResourceProfile")) + + private val defaultProfile = ResourceProfile.getOrCreateDefaultProfile(new SparkConf) + + private def numExecutorsToAddForDefaultProfile(manager: ExecutorAllocationManager): Int = { + numExecutorsToAdd(manager, defaultProfile) + } + + private def numExecutorsToAdd( + manager: ExecutorAllocationManager, + rp: ResourceProfile): Int = { + val nmap = manager invokePrivate _numExecutorsToAddPerResourceProfileId() + nmap(rp.id) + } + + private def updateAndSyncNumExecutorsTarget( + manager: ExecutorAllocationManager, + now: Long): Unit = { + manager invokePrivate _updateAndSyncNumExecutorsTarget(now) + } + + private def numExecutorsTargetForDefaultProfileId(manager: ExecutorAllocationManager): Int = { + numExecutorsTarget(manager, defaultProfile.id) + } - private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _numExecutorsToAdd() + private def numExecutorsTarget( + manager: ExecutorAllocationManager, + rpId: Int): Int = { + val numMap = manager invokePrivate _numExecutorsTargetPerResourceProfileId() + numMap(rpId) } - private def numExecutorsTarget(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _numExecutorsTarget() + private def addExecutorsToTargetForDefaultProfile( + manager: ExecutorAllocationManager, + updatesNeeded: mutable.HashMap[ResourceProfile, + ExecutorAllocationManager.TargetNumUpdates] + ): Int = { + addExecutorsToTarget(manager, updatesNeeded, defaultProfile) + } + + private def addExecutorsToTarget( + manager: ExecutorAllocationManager, + updatesNeeded: mutable.HashMap[ResourceProfile, + ExecutorAllocationManager.TargetNumUpdates], + rp: ResourceProfile + ): Int = { + val maxNumExecutorsNeeded = + manager invokePrivate _maxNumExecutorsNeededPerResourceProfile(rp.id) + manager invokePrivate + _addExecutorsToTarget(maxNumExecutorsNeeded, rp.id, updatesNeeded) } private def addTime(manager: ExecutorAllocationManager): Long = { manager invokePrivate _addTime() } 
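Most tests above go through the *ForDefaultProfile shortcuts, but each shortcut delegates to an rp-aware variant that takes an explicit ResourceProfile or profile id. A hypothetical sketch of driving those variants directly from a test body, using defaultProfile as a stand-in for any profile and an arbitrary executor id ("executor-a"); manager and clock are assumed to be created as in the other tests:

    val clock = new ManualClock()
    val manager = createManager(createConf(0, 10, 0), clock = clock)
    val updatesNeeded =
      new mutable.HashMap[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates]

    // Submit a stage tagged with the profile under test (here the default one).
    post(SparkListenerStageSubmitted(createStageInfo(0, 4, rp = defaultProfile)))

    // The rp-aware calls mirror the default-profile shortcuts used above.
    addExecutorsToTarget(manager, updatesNeeded, defaultProfile)
    doUpdateRequest(manager, updatesNeeded.toMap, clock.getTimeMillis())
    onExecutorAdded(manager, "executor-a", defaultProfile)

    // For the default profile both accessors resolve to the same per-profile entry.
    assert(numExecutorsTarget(manager, defaultProfile.id) ===
      numExecutorsTargetForDefaultProfileId(manager))
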
- private def schedule(manager: ExecutorAllocationManager): Unit = { - manager invokePrivate _schedule() + private def doUpdateRequest( + manager: ExecutorAllocationManager, + updates: Map[ResourceProfile, ExecutorAllocationManager.TargetNumUpdates], + now: Long): Unit = { + manager invokePrivate _doUpdateRequest(updates, now) } - private def maxNumExecutorsNeeded(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _maxNumExecutorsNeeded() + private def schedule(manager: ExecutorAllocationManager): Unit = { + manager invokePrivate _schedule() } - private def addExecutors(manager: ExecutorAllocationManager): Int = { - val maxNumExecutorsNeeded = manager invokePrivate _maxNumExecutorsNeeded() - manager invokePrivate _addExecutors(maxNumExecutorsNeeded) + private def maxNumExecutorsNeededPerResourceProfile( + manager: ExecutorAllocationManager, + rp: ResourceProfile): Int = { + manager invokePrivate _maxNumExecutorsNeededPerResourceProfile(rp.id) } private def adjustRequestedExecutors(manager: ExecutorAllocationManager): Int = { manager invokePrivate _updateAndSyncNumExecutorsTarget(0L) } - private def removeExecutors(manager: ExecutorAllocationManager, ids: Seq[String]): Seq[String] = { + private def removeExecutorsDefaultProfile( + manager: ExecutorAllocationManager, + ids: Seq[String]): Seq[String] = { + val idsAndProfileIds = ids.map((_, defaultProfile.id)) + manager invokePrivate _removeExecutors(idsAndProfileIds) + } + + private def removeExecutors( + manager: ExecutorAllocationManager, + ids: Seq[(String, Int)]): Seq[String] = { manager invokePrivate _removeExecutors(ids) } @@ -1280,15 +1638,22 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { manager invokePrivate _onSpeculativeTaskSubmitted(id) } - private def localityAwareTasks(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _localityAwareTasks() + private def localityAwareTasksForDefaultProfile(manager: ExecutorAllocationManager): Int = { + val localMap = manager invokePrivate _localityAwareTasksPerResourceProfileId() + localMap(defaultProfile.id) + } + + private def totalRunningTasksPerResourceProfile(manager: ExecutorAllocationManager): Int = { + manager invokePrivate _totalRunningTasksPerResourceProfile(defaultProfile.id) } - private def totalRunningTasks(manager: ExecutorAllocationManager): Int = { - manager invokePrivate _totalRunningTasks() + private def hostToLocalTaskCount( + manager: ExecutorAllocationManager): Map[String, Int] = { + val rpIdToHostLocal = manager invokePrivate _rpIdToHostToLocalTaskCount() + rpIdToHostLocal(defaultProfile.id) } - private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = { - manager invokePrivate _hostToLocalTaskCount() + private def getResourceProfileIdOfExecutor(manager: ExecutorAllocationManager): Int = { + defaultProfile.id } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index ff0f2f9134ed3..a9296955d18b4 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -285,9 +285,14 @@ private class FakeSchedulerBackend( clusterManagerEndpoint: RpcEndpointRef) extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) { - protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { + protected override def doRequestTotalExecutors( + resourceProfileToTotalExecs: 
Map[ResourceProfile, Int]): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty)) + RequestExecutors( + resourceProfileToTotalExecs(ResourceProfile.getOrCreateDefaultProfile(conf)), + numLocalityAwareTasksPerResourceProfileId(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID), + rpHostToLocalTaskCount(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID), + Set.empty)) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 1fe12e116d96e..599ea8955491f 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -44,7 +44,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self def resetSparkContext(): Unit = { LocalSparkContext.stop(sc) - ResourceProfile.clearDefaultProfile + ResourceProfile.clearDefaultProfile() sc = null } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index df9c7c5eaa368..b6dfa69015c28 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -36,6 +36,7 @@ import org.scalatest.concurrent.Eventually import org.apache.spark.TestUtils._ import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests._ import org.apache.spark.internal.config.UI._ import org.apache.spark.resource.ResourceAllocation import org.apache.spark.resource.ResourceUtils._ @@ -784,7 +785,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } test(s"Avoid setting ${CPUS_PER_TASK.key} unreasonably (SPARK-27192)") { - val FAIL_REASON = s"has to be >= the task config: ${CPUS_PER_TASK.key}" + val FAIL_REASON = " has to be >= the number of cpus per task" Seq( ("local", 2, None), ("local[2]", 3, None), @@ -864,9 +865,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(conf) }.getMessage() - assert(error.contains("The executor resource config: spark.executor.resource.gpu.amount " + - "needs to be specified since a task requirement config: spark.task.resource.gpu.amount " + - "was specified")) + assert(error.contains("No executor resource configs were not specified for the following " + + "task configs: gpu")) } test("Test parsing resources executor config < task requirements") { @@ -880,15 +880,15 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc = new SparkContext(conf) }.getMessage() - assert(error.contains("The executor resource config: spark.executor.resource.gpu.amount = 1 " + - "has to be >= the requested amount in task resource config: " + - "spark.task.resource.gpu.amount = 2")) + assert(error.contains("The executor resource: gpu, amount: 1 needs to be >= the task " + + "resource request amount of 2.0")) } test("Parse resources executor config not the same multiple numbers of the task requirements") { val conf = new SparkConf() .setMaster("local-cluster[1, 1, 1024]") .setAppName("test-cluster") + conf.set(RESOURCES_WARNING_TESTING, true) conf.set(TASK_GPU_ID.amountConf, "2") conf.set(EXECUTOR_GPU_ID.amountConf, "4") @@ -897,25 +897,9 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu 
}.getMessage() assert(error.contains( - "The configuration of resource: gpu (exec = 4, task = 2, runnable tasks = 2) will result " + - "in wasted resources due to resource CPU limiting the number of runnable tasks per " + - "executor to: 1. Please adjust your configuration.")) - } - - test("Parse resources executor config cpus not limiting resource") { - val conf = new SparkConf() - .setMaster("local-cluster[1, 8, 1024]") - .setAppName("test-cluster") - conf.set(TASK_GPU_ID.amountConf, "2") - conf.set(EXECUTOR_GPU_ID.amountConf, "4") - - var error = intercept[IllegalArgumentException] { - sc = new SparkContext(conf) - }.getMessage() - - assert(error.contains("The number of slots on an executor has to be " + - "limited by the number of cores, otherwise you waste resources and " + - "dynamic allocation doesn't work properly")) + "The configuration of resource: gpu (exec = 4, task = 2.0/1, runnable tasks = 2) will " + + "result in wasted resources due to resource cpus limiting the number of runnable " + + "tasks per executor to: 1. Please adjust your configuration.")) } test("test resource scheduling under local-cluster mode") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterBuilderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterBuilderSuite.scala index 86511ae08784a..c905797bf1287 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterBuilderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/BasicEventFilterBuilderSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.history import org.apache.spark.{SparkFunSuite, Success, TaskResultLost, TaskState} import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ import org.apache.spark.status.ListenerEventsTestHelper @@ -141,7 +142,8 @@ class BasicEventFilterBuilderSuite extends SparkFunSuite { // - Re-submit stage 1, all tasks, and succeed them and the stage. val oldS1 = stages.last val newS1 = new StageInfo(oldS1.stageId, oldS1.attemptNumber + 1, oldS1.name, oldS1.numTasks, - oldS1.rddInfos, oldS1.parentIds, oldS1.details, oldS1.taskMetrics) + oldS1.rddInfos, oldS1.parentIds, oldS1.details, oldS1.taskMetrics, + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) time += 1 newS1.submissionTime = Some(time) diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala new file mode 100644 index 0000000000000..075260317284d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/resource/ResourceProfileManagerSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.resource + +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests._ + +class ResourceProfileManagerSuite extends SparkFunSuite { + + override def beforeAll() { + try { + ResourceProfile.clearDefaultProfile() + } finally { + super.beforeAll() + } + } + + override def afterEach() { + try { + ResourceProfile.clearDefaultProfile() + } finally { + super.afterEach() + } + } + + test("ResourceProfileManager") { + val conf = new SparkConf().set(EXECUTOR_CORES, 4) + val rpmanager = new ResourceProfileManager(conf) + val defaultProf = rpmanager.defaultResourceProfile + assert(defaultProf.id === ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + assert(defaultProf.executorResources.size === 2, + "Executor resources should contain cores and memory by default") + assert(defaultProf.executorResources(ResourceProfile.CORES).amount === 4, + s"Executor resources should have 4 cores") + } + + test("isSupported yarn no dynamic allocation") { + val conf = new SparkConf().setMaster("yarn").set(EXECUTOR_CORES, 4) + conf.set(RESOURCE_PROFILE_MANAGER_TESTING.key, "true") + val rpmanager = new ResourceProfileManager(conf) + // default profile should always work + val defaultProf = rpmanager.defaultResourceProfile + val rprof = new ResourceProfileBuilder() + val gpuExecReq = + new ExecutorResourceRequests().resource("gpu", 2, "someScript") + val immrprof = rprof.require(gpuExecReq).build + val error = intercept[SparkException] { + rpmanager.isSupported(immrprof) + }.getMessage() + + assert(error.contains("ResourceProfiles are only supported on YARN with dynamic allocation")) + } + + test("isSupported yarn with dynamic allocation") { + val conf = new SparkConf().setMaster("yarn").set(EXECUTOR_CORES, 4) + conf.set(DYN_ALLOCATION_ENABLED, true) + conf.set(RESOURCE_PROFILE_MANAGER_TESTING.key, "true") + val rpmanager = new ResourceProfileManager(conf) + // default profile should always work + val defaultProf = rpmanager.defaultResourceProfile + val rprof = new ResourceProfileBuilder() + val gpuExecReq = + new ExecutorResourceRequests().resource("gpu", 2, "someScript") + val immrprof = rprof.require(gpuExecReq).build + assert(rpmanager.isSupported(immrprof) == true) + } + + test("isSupported yarn with local mode") { + val conf = new SparkConf().setMaster("local").set(EXECUTOR_CORES, 4) + conf.set(RESOURCE_PROFILE_MANAGER_TESTING.key, "true") + val rpmanager = new ResourceProfileManager(conf) + // default profile should always work + val defaultProf = rpmanager.defaultResourceProfile + val rprof = new ResourceProfileBuilder() + val gpuExecReq = + new ExecutorResourceRequests().resource("gpu", 2, "someScript") + val immrprof = rprof.require(gpuExecReq).build + var error = intercept[SparkException] { + rpmanager.isSupported(immrprof) + }.getMessage() + + assert(error.contains("ResourceProfiles are only supported on YARN with dynamic allocation")) + } + + + +} diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceProfileSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceProfileSuite.scala index c0637eeeacaba..b2f2c3632e454 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceProfileSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceProfileSuite.scala @@ -18,18 +18,28 @@ package org.apache.spark.resource import 
org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.internal.config.{EXECUTOR_CORES, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, SPARK_EXECUTOR_PREFIX} +import org.apache.spark.internal.config.{EXECUTOR_CORES, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} import org.apache.spark.internal.config.Python.PYSPARK_EXECUTOR_MEMORY +import org.apache.spark.resource.TestResourceIDs._ class ResourceProfileSuite extends SparkFunSuite { + override def beforeAll() { + try { + ResourceProfile.clearDefaultProfile() + } finally { + super.beforeAll() + } + } + override def afterEach() { try { - ResourceProfile.clearDefaultProfile + ResourceProfile.clearDefaultProfile() } finally { super.afterEach() } } + test("Default ResourceProfile") { val rprof = ResourceProfile.getOrCreateDefaultProfile(new SparkConf) assert(rprof.id === ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) @@ -59,18 +69,19 @@ class ResourceProfileSuite extends SparkFunSuite { conf.set(EXECUTOR_MEMORY_OVERHEAD.key, "1g") conf.set(EXECUTOR_MEMORY.key, "4g") conf.set(EXECUTOR_CORES.key, "4") - conf.set("spark.task.resource.gpu.amount", "1") - conf.set(s"$SPARK_EXECUTOR_PREFIX.resource.gpu.amount", "1") - conf.set(s"$SPARK_EXECUTOR_PREFIX.resource.gpu.discoveryScript", "nameOfScript") + conf.set(TASK_GPU_ID.amountConf, "1") + conf.set(EXECUTOR_GPU_ID.amountConf, "1") + conf.set(EXECUTOR_GPU_ID.discoveryScriptConf, "nameOfScript") val rprof = ResourceProfile.getOrCreateDefaultProfile(conf) assert(rprof.id === ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val execResources = rprof.executorResources - assert(execResources.size === 5, - "Executor resources should contain cores, memory, and gpu " + execResources) + assert(execResources.size === 5, s"Executor resources should contain cores, pyspark " + + s"memory, memory overhead, memory, and gpu $execResources") assert(execResources.contains("gpu"), "Executor resources should have gpu") assert(rprof.executorResources(ResourceProfile.CORES).amount === 4, "Executor resources should have 4 core") - assert(rprof.getExecutorCores.get === 4, "Executor resources should have 4 core") + assert(rprof.getExecutorCores.get === 4, + "Executor resources should have 4 core") assert(rprof.executorResources(ResourceProfile.MEMORY).amount === 4096, "Executor resources should have 1024 memory") assert(rprof.executorResources(ResourceProfile.PYSPARK_MEM).amount == 2048, @@ -84,12 +95,60 @@ class ResourceProfileSuite extends SparkFunSuite { test("test default profile task gpus fractional") { val sparkConf = new SparkConf() - .set("spark.executor.resource.gpu.amount", "2") - .set("spark.task.resource.gpu.amount", "0.33") + .set(EXECUTOR_GPU_ID.amountConf, "2") + .set(TASK_GPU_ID.amountConf, "0.33") val immrprof = ResourceProfile.getOrCreateDefaultProfile(sparkConf) assert(immrprof.taskResources.get("gpu").get.amount == 0.33) } + test("maxTasksPerExecutor cpus") { + val sparkConf = new SparkConf() + .set(EXECUTOR_CORES, 1) + val rprof = new ResourceProfileBuilder() + val taskReq = new TaskResourceRequests().resource("gpu", 1) + val execReq = + new ExecutorResourceRequests().resource("gpu", 2, "myscript", "nvidia") + rprof.require(taskReq).require(execReq) + val immrprof = new ResourceProfile(rprof.executorResources, rprof.taskResources) + assert(immrprof.limitingResource(sparkConf) == "cpus") + assert(immrprof.maxTasksPerExecutor(sparkConf) == 1) + } + + test("maxTasksPerExecutor/limiting no executor cores") { + val sparkConf = new SparkConf().setMaster("spark://testing") + val rprof = new 
ResourceProfileBuilder() + val taskReq = new TaskResourceRequests().resource("gpu", 1) + val execReq = + new ExecutorResourceRequests().resource("gpu", 2, "myscript", "nvidia") + rprof.require(taskReq).require(execReq) + val immrprof = new ResourceProfile(rprof.executorResources, rprof.taskResources) + assert(immrprof.limitingResource(sparkConf) == "gpu") + assert(immrprof.maxTasksPerExecutor(sparkConf) == 2) + assert(immrprof.isCoresLimitKnown == false) + } + + test("maxTasksPerExecutor/limiting no other resource no executor cores") { + val sparkConf = new SparkConf().setMaster("spark://testing") + val immrprof = ResourceProfile.getOrCreateDefaultProfile(sparkConf) + assert(immrprof.limitingResource(sparkConf) == "") + assert(immrprof.maxTasksPerExecutor(sparkConf) == 1) + assert(immrprof.isCoresLimitKnown == false) + } + + test("maxTasksPerExecutor/limiting executor cores") { + val sparkConf = new SparkConf().setMaster("spark://testing").set(EXECUTOR_CORES, 2) + val rprof = new ResourceProfileBuilder() + val taskReq = new TaskResourceRequests().resource("gpu", 1) + val execReq = + new ExecutorResourceRequests().resource("gpu", 2, "myscript", "nvidia") + rprof.require(taskReq).require(execReq) + val immrprof = new ResourceProfile(rprof.executorResources, rprof.taskResources) + assert(immrprof.limitingResource(sparkConf) == ResourceProfile.CPUS) + assert(immrprof.maxTasksPerExecutor(sparkConf) == 2) + assert(immrprof.isCoresLimitKnown == true) + } + + test("Create ResourceProfile") { val rprof = new ResourceProfileBuilder() val taskReq = new TaskResourceRequests().resource("gpu", 1) diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala index dffe9a02e9aa4..278a72a7192d8 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceUtilsSuite.scala @@ -26,8 +26,10 @@ import org.json4s.{DefaultFormats, Extraction} import org.apache.spark.{LocalSparkContext, SparkConf, SparkException, SparkFunSuite} import org.apache.spark.TestUtils._ import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.Tests._ import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.resource.TestResourceIDs._ +import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.util.Utils class ResourceUtilsSuite extends SparkFunSuite @@ -165,6 +167,7 @@ class ResourceUtilsSuite extends SparkFunSuite val rpBuilder = new ResourceProfileBuilder() val ereqs = new ExecutorResourceRequests().resource(GPU, 2, gpuDiscovery) val treqs = new TaskResourceRequests().resource(GPU, 1) + val rp = rpBuilder.require(ereqs).require(treqs).build val resourcesFromBoth = getOrDiscoverAllResourcesForResourceProfile( Some(resourcesFile), SPARK_EXECUTOR_PREFIX, rp, conf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index c063301673598..7666c6c7810cc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark._ import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.RPC_MESSAGE_MAX_SIZE import org.apache.spark.rdd.RDD -import 
org.apache.spark.resource.{ResourceInformation, ResourceProfile} +import org.apache.spark.resource.{ExecutorResourceRequests, ResourceInformation, ResourceProfile, TaskResourceRequests} import org.apache.spark.resource.ResourceUtils._ import org.apache.spark.resource.TestResourceIDs._ import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} @@ -187,8 +187,6 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo } test("extra resources from executor") { - import TestUtils._ - val conf = new SparkConf() .set(EXECUTOR_CORES, 1) .set(SCHEDULER_REVIVE_INTERVAL.key, "1m") // don't let it auto revive during test @@ -200,6 +198,11 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo conf.set(EXECUTOR_GPU_ID.amountConf, "1") sc = new SparkContext(conf) + val execGpu = new ExecutorResourceRequests().cores(1).resource(GPU, 3) + val taskGpu = new TaskResourceRequests().cpus(1).resource(GPU, 1) + val rp = new ResourceProfile(execGpu.requests, taskGpu.requests) + sc.resourceProfileManager.addResourceProfile(rp) + assert(rp.id > ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val backend = sc.schedulerBackend.asInstanceOf[TestCoarseGrainedSchedulerBackend] val mockEndpointRef = mock[RpcEndpointRef] val mockAddress = mock[RpcAddress] @@ -224,7 +227,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) backend.driverEndpoint.askSync[Boolean]( RegisterExecutor("3", mockEndpointRef, mockAddress.host, 1, Map.empty, Map.empty, resources, - 5)) + rp.id)) val frameSize = RpcUtils.maxMessageSizeBytes(sc.conf) val bytebuffer = java.nio.ByteBuffer.allocate(frameSize - 100) @@ -234,7 +237,7 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(execResources(GPU).availableAddrs.sorted === Array("0", "1", "3")) var exec3ResourceProfileId = backend.getExecutorResourceProfileId("3") - assert(exec3ResourceProfileId === 5) + assert(exec3ResourceProfileId === rp.id) val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0"))) var taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1", diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 286924001e920..61ea21fa86c5a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -38,6 +38,7 @@ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.metrics.{ExecutorMetricType, MetricsSystem} +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.{JsonProtocol, Utils} @@ -438,12 +439,14 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit private def createStageSubmittedEvent(stageId: Int) = { SparkListenerStageSubmitted(new StageInfo(stageId, 0, stageId.toString, 0, - Seq.empty, Seq.empty, "details")) + Seq.empty, Seq.empty, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) } private def createStageCompletedEvent(stageId: Int) = { SparkListenerStageCompleted(new StageInfo(stageId, 0, stageId.toString, 0, - Seq.empty, Seq.empty, "details")) + Seq.empty, Seq.empty, "details", + 
resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) } private def createExecutorAddedEvent(executorId: Int) = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala index 615389ae5c2d4..3596a9ebb1f5a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/dynalloc/ExecutorMonitorSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark._ import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config._ import org.apache.spark.resource.ResourceProfile.{DEFAULT_RESOURCE_PROFILE_ID, UNKNOWN_RESOURCE_PROFILE_ID} +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage._ @@ -255,25 +256,28 @@ class ExecutorMonitorSuite extends SparkFunSuite { test("track executors pending for removal") { knownExecs ++= Set("1", "2", "3") + val execInfoRp1 = new ExecutorInfo("host1", 1, Map.empty, + Map.empty, Map.empty, 1) + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "1", execInfo)) monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "2", execInfo)) - monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "3", execInfo)) + monitor.onExecutorAdded(SparkListenerExecutorAdded(clock.getTimeMillis(), "3", execInfoRp1)) clock.setTime(idleDeadline) - assert(monitor.timedOutExecutors().toSet === Set("1", "2", "3")) + assert(monitor.timedOutExecutors().toSet === Set(("1", 0), ("2", 0), ("3", 1))) assert(monitor.pendingRemovalCount === 0) // Notify that only a subset of executors was killed, to mimic the case where the scheduler // refuses to kill an executor that is busy for whatever reason the monitor hasn't detected yet. monitor.executorsKilled(Seq("1")) - assert(monitor.timedOutExecutors().toSet === Set("2", "3")) + assert(monitor.timedOutExecutors().toSet === Set(("2", 0), ("3", 1))) assert(monitor.pendingRemovalCount === 1) // Check the timed out executors again so that we're sure they're still timed out when no // events happen. This ensures that the monitor doesn't lose track of them. 
- assert(monitor.timedOutExecutors().toSet === Set("2", "3")) + assert(monitor.timedOutExecutors().toSet === Set(("2", 0), ("3", 1))) monitor.onTaskStart(SparkListenerTaskStart(1, 1, taskInfo("2", 1))) - assert(monitor.timedOutExecutors().toSet === Set("3")) + assert(monitor.timedOutExecutors().toSet === Set(("3", 1))) monitor.executorsKilled(Seq("3")) assert(monitor.pendingRemovalCount === 2) @@ -282,7 +286,7 @@ class ExecutorMonitorSuite extends SparkFunSuite { new ExecutorMetrics, null)) assert(monitor.timedOutExecutors().isEmpty) clock.advance(idleDeadline) - assert(monitor.timedOutExecutors().toSet === Set("2")) + assert(monitor.timedOutExecutors().toSet === Set(("2", 0))) } test("shuffle block tracking") { @@ -435,7 +439,8 @@ class ExecutorMonitorSuite extends SparkFunSuite { private def stageInfo(id: Int, shuffleId: Int = -1): StageInfo = { new StageInfo(id, 0, s"stage$id", 1, Nil, Nil, "", - shuffleDepId = if (shuffleId >= 0) Some(shuffleId) else None) + shuffleDepId = if (shuffleId >= 0) Some(shuffleId) else None, + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) } private def taskInfo( diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 255f91866ef58..24eb1685f577a 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.config.Status._ import org.apache.spark.metrics.ExecutorMetricType +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster._ import org.apache.spark.status.ListenerEventsTestHelper._ @@ -151,8 +152,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Start a job with 2 stages / 4 tasks each time += 1 val stages = Seq( - new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), - new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2")) + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID), + new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) val jobProps = new Properties() jobProps.setProperty(SparkContext.SPARK_JOB_DESCRIPTION, "jobDescription") @@ -524,7 +527,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // - Re-submit stage 2, all tasks, and succeed them and the stage. val oldS2 = stages.last val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptNumber + 1, oldS2.name, oldS2.numTasks, - oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) + oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics, + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) time += 1 newS2.submissionTime = Some(time) @@ -575,8 +579,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // change the stats of the already finished job. 
time += 1 val j2Stages = Seq( - new StageInfo(3, 0, "stage1", 4, Nil, Nil, "details1"), - new StageInfo(4, 0, "stage2", 4, Nil, Seq(3), "details2")) + new StageInfo(3, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID), + new StageInfo(4, 0, "stage2", 4, Nil, Seq(3), "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) j2Stages.last.submissionTime = Some(time) listener.onJobStart(SparkListenerJobStart(2, time, j2Stages, null)) assert(store.count(classOf[JobDataWrapper]) === 2) @@ -703,7 +709,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Submit a stage for the first RDD before it's marked for caching, to make sure later // the listener picks up the correct storage level. val rdd1Info = new RDDInfo(rdd1b1.rddId, "rdd1", 2, StorageLevel.NONE, false, Nil) - val stage0 = new StageInfo(0, 0, "stage0", 4, Seq(rdd1Info), Nil, "details0") + val stage0 = new StageInfo(0, 0, "stage0", 4, Seq(rdd1Info), Nil, "details0", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) listener.onStageSubmitted(SparkListenerStageSubmitted(stage0, new Properties())) listener.onStageCompleted(SparkListenerStageCompleted(stage0)) assert(store.count(classOf[RDDStorageInfoWrapper]) === 0) @@ -711,7 +718,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Submit a stage and make sure the RDDs are recorded. rdd1Info.storageLevel = level val rdd2Info = new RDDInfo(rdd2b1.rddId, "rdd2", 1, level, false, Nil) - val stage = new StageInfo(1, 0, "stage1", 4, Seq(rdd1Info, rdd2Info), Nil, "details1") + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rdd1Info, rdd2Info), Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => @@ -1018,9 +1026,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // data is not deleted. time += 1 val stages = Seq( - new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), - new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2"), - new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3")) + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID), + new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID), + new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) // Graph data is generated by the job start event, so fire it. 
listener.onJobStart(SparkListenerJobStart(4, time, stages, null)) @@ -1068,7 +1079,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { } assert(store.count(classOf[CachedQuantile], "stage", key(dropped)) === 0) - val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3") + val attempt2 = new StageInfo(3, 1, "stage3", 4, Nil, Nil, "details3", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) time += 1 attempt2.submissionTime = Some(time) listener.onStageSubmitted(SparkListenerStageSubmitted(attempt2, new Properties())) @@ -1139,9 +1151,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) val listener = new AppStatusListener(store, testConf, true) - val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") - val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") - val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) // Start stage 1 and stage 2 time += 1 @@ -1172,8 +1187,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val testConf = conf.clone().set(MAX_RETAINED_STAGES, 2) val listener = new AppStatusListener(store, testConf, true) - val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") - val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) // Sart job 1 time += 1 @@ -1193,7 +1210,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) // Submit stage 3 and verify stage 2 is evicted - val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3") + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) time += 1 stage3.submissionTime = Some(time) listener.onStageSubmitted(SparkListenerStageSubmitted(stage3, new Properties())) @@ -1208,7 +1226,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val testConf = conf.clone().set(MAX_RETAINED_TASKS_PER_STAGE, 2) val listener = new AppStatusListener(store, testConf, true) - val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) stage1.submissionTime = Some(time) listener.onStageSubmitted(SparkListenerStageSubmitted(stage1, new Properties())) @@ -1243,9 +1262,12 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { val listener = new AppStatusListener(store, testConf, true) val appStore = new AppStatusStore(store) - val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") - val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") - val stage3 = new StageInfo(3, 0, 
"stage3", 4, Nil, Nil, "details3") + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stage3 = new StageInfo(3, 0, "stage3", 4, Nil, Nil, "details3", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) time += 1 stage1.submissionTime = Some(time) @@ -1274,8 +1296,10 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK-24415: update metrics for tasks that finish late") { val listener = new AppStatusListener(store, conf, true) - val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1") - val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2") + val stage1 = new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) + val stage2 = new StageInfo(2, 0, "stage2", 4, Nil, Nil, "details2", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) // Start job listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage1, stage2), null)) @@ -1340,7 +1364,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { listener.onExecutorAdded(createExecutorAddedEvent(1)) listener.onExecutorAdded(createExecutorAddedEvent(2)) - val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details") + val stage = new StageInfo(1, 0, "stage", 4, Nil, Nil, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) listener.onJobStart(SparkListenerJobStart(1, time, Seq(stage), null)) listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) @@ -1577,7 +1602,8 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { // Submit a stage and make sure the RDDs are recorded. val rdd1Info = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, false, Nil) - val stage = new StageInfo(1, 0, "stage1", 4, Seq(rdd1Info), Nil, "details1") + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rdd1Info), Nil, "details1", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) // Add partition 1 replicated on two block managers. 
diff --git a/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala b/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala index 4b3fbacc47f9c..99c0d9593ccae 100644 --- a/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala +++ b/core/src/test/scala/org/apache/spark/status/ListenerEventsTestHelper.scala @@ -23,6 +23,7 @@ import scala.collection.immutable.Map import org.apache.spark.{AccumulatorSuite, SparkContext, Success, TaskState} import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded, SparkListenerExecutorMetricsUpdate, SparkListenerExecutorRemoved, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerStageSubmitted, SparkListenerTaskEnd, SparkListenerTaskStart, StageInfo, TaskInfo, TaskLocality} import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{RDDInfo, StorageLevel} @@ -61,7 +62,8 @@ object ListenerEventsTestHelper { } def createStage(id: Int, rdds: Seq[RDDInfo], parentIds: Seq[Int]): StageInfo = { - new StageInfo(id, 0, s"stage${id}", 4, rdds, parentIds, s"details${id}") + new StageInfo(id, 0, s"stage${id}", 4, rdds, parentIds, s"details${id}", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) } def createStage(rdds: Seq[RDDInfo], parentIds: Seq[Int]): StageInfo = { @@ -96,13 +98,15 @@ object ListenerEventsTestHelper { /** Create a stage submitted event for the specified stage Id. */ def createStageSubmittedEvent(stageId: Int): SparkListenerStageSubmitted = { SparkListenerStageSubmitted(new StageInfo(stageId, 0, stageId.toString, 0, - Seq.empty, Seq.empty, "details")) + Seq.empty, Seq.empty, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) } /** Create a stage completed event for the specified stage Id. 
*/ def createStageCompletedEvent(stageId: Int): SparkListenerStageCompleted = { SparkListenerStageCompleted(new StageInfo(stageId, 0, stageId.toString, 0, - Seq.empty, Seq.empty, "details")) + Seq.empty, Seq.empty, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) } def createExecutorAddedEvent(executorId: Int): SparkListenerExecutorAdded = { diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index bd18e9e628da8..7711934cbe8a6 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -27,6 +27,7 @@ import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} import org.apache.spark._ import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.config.Status._ +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.{AccumulableInfo => UIAccumulableInfo, StageData, StageStatus} @@ -131,7 +132,8 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val page = new StagePage(tab, statusStore) // Simulate a stage in job progress listener - val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details") + val stageInfo = new StageInfo(0, 0, "dummy", 1, Seq.empty, Seq.empty, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) // Simulate two tasks to test PEAK_EXECUTION_MEMORY correctness (1 to 2).foreach { taskId => diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a2a4b3aa974fc..edc0662a0f73e 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -32,8 +32,7 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.metrics.ExecutorMetricType import org.apache.spark.rdd.RDDOperationScope -import org.apache.spark.resource.ResourceInformation -import org.apache.spark.resource.ResourceUtils +import org.apache.spark.resource.{ResourceInformation, ResourceProfile, ResourceUtils} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.shuffle.MetadataFetchFailedException @@ -341,7 +340,8 @@ class JsonProtocolSuite extends SparkFunSuite { val stageIds = Seq[Int](1, 2, 3, 4) val stageInfos = stageIds.map(x => makeStageInfo(x, x * 200, x * 300, x * 400L, x * 500L)) val dummyStageInfos = - stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) + stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)) val jobStart = SparkListenerJobStart(10, jobSubmissionTime, stageInfos, properties) val oldEvent = JsonProtocol.jobStartToJson(jobStart).removeField({_._1 == "Stage Infos"}) val expectedJobStart = @@ -383,9 +383,11 @@ class JsonProtocolSuite extends SparkFunSuite { test("StageInfo backward compatibility (parent IDs)") { // Prior to Spark 1.4.0, StageInfo did not have the "Parent IDs" property - val stageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq(1, 2, 3), "details") + val stageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq(1, 2, 3), "details", + resourceProfileId = 
ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val oldStageInfo = JsonProtocol.stageInfoToJson(stageInfo).removeField({ _._1 == "Parent IDs"}) - val expectedStageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq.empty, "details") + val expectedStageInfo = new StageInfo(1, 1, "me-stage", 1, Seq.empty, Seq.empty, "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) assertEquals(expectedStageInfo, JsonProtocol.stageInfoFromJson(oldStageInfo)) } @@ -873,7 +875,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def makeStageInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { val rddInfos = (0 until a % 5).map { i => makeRddInfo(a + i, b + i, c + i, d + i, e + i) } - val stageInfo = new StageInfo(a, 0, "greetings", b, rddInfos, Seq(100, 200, 300), "details") + val stageInfo = new StageInfo(a, 0, "greetings", b, rddInfos, Seq(100, 200, 300), "details", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val (acc1, acc2) = (makeAccumulableInfo(1), makeAccumulableInfo(2)) stageInfo.accumulables(acc1.id) = acc1 stageInfo.accumulables(acc2.id) = acc2 diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py index c7f435a582210..edfea42bed71d 100644 --- a/python/pyspark/tests/test_context.py +++ b/python/pyspark/tests/test_context.py @@ -275,9 +275,13 @@ def setUp(self): self.tempFile = tempfile.NamedTemporaryFile(delete=False) self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}') self.tempFile.close() + # create temporary directory for Worker resources coordination + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH) conf = SparkConf().set("spark.test.home", SPARK_HOME) + conf = conf.set("spark.resources.dir", self.tempdir.name) conf = conf.set("spark.driver.resource.gpu.amount", "1") conf = conf.set("spark.driver.resource.gpu.discoveryScript", self.tempFile.name) self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf) @@ -292,6 +296,7 @@ def test_resources(self): def tearDown(self): os.unlink(self.tempFile.name) + shutil.rmtree(self.tempdir.name) self.sc.stop() diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 6095a384679af..68cfe814762e0 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -16,6 +16,7 @@ # import os import random +import shutil import stat import sys import tempfile @@ -277,9 +278,13 @@ def setUp(self): self.tempFile = tempfile.NamedTemporaryFile(delete=False) self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}') self.tempFile.close() + # create temporary directory for Worker resources coordination + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH | stat.S_IXOTH) conf = SparkConf().set("spark.test.home", SPARK_HOME) + conf = conf.set("spark.resources.dir", self.tempdir.name) conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name) conf = conf.set("spark.worker.resource.gpu.amount", 1) conf = conf.set("spark.task.resource.gpu.amount", "1") @@ -297,6 +302,7 @@ def test_resources(self): def tearDown(self): os.unlink(self.tempFile.name) + shutil.rmtree(self.tempdir.name) self.sc.stop() if __name__ == "__main__": diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 105841ac834b3..5655ef50d214f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -27,6 +27,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config.SCHEDULER_MIN_REGISTERED_RESOURCES_RATIO +import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc.RpcAddress import org.apache.spark.scheduler.{ExecutorKilled, ExecutorLossReason, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -55,6 +56,8 @@ private[spark] class KubernetesClusterSchedulerBackend( private val shouldDeleteExecutors = conf.get(KUBERNETES_DELETE_EXECUTORS) + private val defaultProfile = scheduler.sc.resourceProfileManager.defaultResourceProfile + // Allow removeExecutor to be accessible by ExecutorPodsLifecycleEventHandler private[k8s] def doRemoveExecutor(executorId: String, reason: ExecutorLossReason): Unit = { if (isExecutorActive(executorId)) { @@ -116,8 +119,9 @@ private[spark] class KubernetesClusterSchedulerBackend( } } - override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { - podAllocator.setTotalExpectedExecutors(requestedTotal) + override def doRequestTotalExecutors( + resourceProfileToTotalExecs: Map[ResourceProfile, Int]): Future[Boolean] = { + podAllocator.setTotalExpectedExecutors(resourceProfileToTotalExecs(defaultProfile)) Future.successful(true) } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index 7e1e39c85a183..8c683e85dd5e2 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ +import org.apache.spark.resource.ResourceProfileManager import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{ExecutorKilled, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor @@ -86,10 +87,13 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _ private var schedulerBackendUnderTest: KubernetesClusterSchedulerBackend = _ + private val resourceProfileManager = new ResourceProfileManager(sparkConf) + before { MockitoAnnotations.initMocks(this) when(taskScheduler.sc).thenReturn(sc) when(sc.conf).thenReturn(sparkConf) + 
when(sc.resourceProfileManager).thenReturn(resourceProfileManager) when(sc.env).thenReturn(env) when(env.rpcEnv).thenReturn(rpcEnv) driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint]) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index e916125ffdb67..0b447025c8a7a 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalBlockStoreClient +import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc.{RpcEndpointAddress, RpcEndpointRef} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -181,6 +182,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private var schedulerDriver: SchedulerDriver = _ + private val defaultProfile = sc.resourceProfileManager.defaultResourceProfile + + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -595,13 +599,16 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def satisfiesLocality(offerHostname: String): Boolean = { + val hostToLocalTaskCount = + rpHostToLocalTaskCount.getOrElse(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, Map.empty) if (!Utils.isDynamicAllocationEnabled(conf) || hostToLocalTaskCount.isEmpty) { return true } // Check the locality information val currentHosts = slaves.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet - val allDesiredHosts = hostToLocalTaskCount.keys.toSet + val allDesiredHosts = hostToLocalTaskCount.map { case (k, v) => k }.toSet + // Try to match locality for hosts which do not have executors yet, to potentially // increase coverage. val remainingHosts = allDesiredHosts -- currentHosts @@ -759,11 +766,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( super.applicationId } - override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future.successful { + override def doRequestTotalExecutors( + resourceProfileToTotalExecs: Map[ResourceProfile, Int] + ): Future[Boolean] = Future.successful { // We don't truly know if we can fulfill the full amount of executors // since at coarse grain it depends on the amount of slaves available. - logInfo("Capping the total amount of executors to " + requestedTotal) - executorLimitOption = Some(requestedTotal) + val numExecs = resourceProfileToTotalExecs.getOrElse(defaultProfile, 0) + logInfo("Capping the total amount of executors to " + numExecs) + executorLimitOption = Some(numExecs) // Update the locality wait start time to continue trying for locality. 
localityWaitStartTimeNs = System.nanoTime() true diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 1876861700fc0..5ab277ed87a72 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -71,8 +71,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite offerResources(offers) verifyTaskLaunched(driver, "o1") + val totalExecs = Map(ResourceProfile.getOrCreateDefaultProfile(sparkConf) -> 0) // kills executors - assert(backend.doRequestTotalExecutors(0).futureValue) + val defaultResourceProfile = ResourceProfile.getOrCreateDefaultProfile(sparkConf) + assert(backend.doRequestTotalExecutors(Map(defaultResourceProfile -> 0)).futureValue) assert(backend.doKillExecutors(Seq("0")).futureValue) val taskID0 = createTaskId("0") verify(driver, times(1)).killTask(taskID0) @@ -82,7 +84,7 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyDeclinedOffer(driver, createOfferId("o2")) // Launches a new task when requested executors is positive - backend.doRequestTotalExecutors(2) + backend.doRequestTotalExecutors(Map(defaultResourceProfile -> 2)) offerResources(offers, 2) verifyTaskLaunched(driver, "o2") } @@ -635,7 +637,12 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(backend.getExecutorIds().isEmpty) - backend.requestTotalExecutors(2, 2, Map("hosts10" -> 1, "hosts11" -> 1)) + val defaultProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID + val defaultProf = ResourceProfile.getOrCreateDefaultProfile(sparkConf) + backend.requestTotalExecutors( + Map(defaultProfileId -> 2), + Map(defaultProfileId -> 2), + Map(defaultProfileId -> Map("hosts10" -> 1, "hosts11" -> 1))) // Offer non-local resources, which should be rejected offerResourcesAndVerify(1, false) @@ -651,7 +658,11 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite offerResourcesAndVerify(1, true) // Update total executors - backend.requestTotalExecutors(3, 3, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1)) + backend.requestTotalExecutors( + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 3), + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 2), + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> + Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1))) // Offer non-local resources, which should be rejected offerResourcesAndVerify(3, false) @@ -660,8 +671,11 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite Thread.sleep(2000) // Update total executors - backend.requestTotalExecutors(4, 4, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1, - "hosts13" -> 1)) + backend.requestTotalExecutors( + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 4), + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 4), + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> + Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1, "hosts13" -> 1))) // Offer non-local resources, which should be rejected offerResourcesAndVerify(3, false) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala 
b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 471ee58d05cb8..f8bbc39c8bcc5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -27,13 +27,13 @@ import scala.util.{Failure, Success} import scala.util.control.NonFatal import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId} -import org.eclipse.jetty.servlet.{FilterHolder, FilterMapping} import org.apache.spark.SparkContext import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.internal.config.UI._ +import org.apache.spark.resource.ResourceProfile import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -92,7 +92,7 @@ private[spark] abstract class YarnSchedulerBackend( try { // SPARK-12009: To prevent Yarn allocator from requesting backup for the executors which // was Stopped by SchedulerBackend. - requestTotalExecutors(0, 0, Map.empty) + requestTotalExecutors(Map.empty, Map.empty, Map.empty) super.stop() } finally { stopped.set(true) @@ -123,21 +123,28 @@ private[spark] abstract class YarnSchedulerBackend( } } - private[cluster] def prepareRequestExecutors(requestedTotal: Int): RequestExecutors = { + private[cluster] def prepareRequestExecutors( + resourceProfileToTotalExecs: Map[ResourceProfile, Int]): RequestExecutors = { val nodeBlacklist: Set[String] = scheduler.nodeBlacklist() // For locality preferences, ignore preferences for nodes that are blacklisted - val filteredHostToLocalTaskCount = - hostToLocalTaskCount.filter { case (k, v) => !nodeBlacklist.contains(k) } - RequestExecutors(requestedTotal, localityAwareTasks, filteredHostToLocalTaskCount, - nodeBlacklist) + val filteredRPHostToLocalTaskCount = rpHostToLocalTaskCount.map { case (rpid, v) => + (rpid, v.filter { case (host, count) => !nodeBlacklist.contains(host) }) + } + // TODO - default everything to default profile until YARN pieces + val defaultProf = ResourceProfile.getOrCreateDefaultProfile(conf) + val hostToLocalTaskCount = filteredRPHostToLocalTaskCount.getOrElse(defaultProf.id, Map.empty) + val localityAwareTasks = numLocalityAwareTasksPerResourceProfileId.getOrElse(defaultProf.id, 0) + val numExecutors = resourceProfileToTotalExecs.getOrElse(defaultProf, 0) + RequestExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount, nodeBlacklist) } /** * Request executors from the ApplicationMaster by specifying the total number desired. * This includes executors already pending or running. 
*/ - override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { - yarnSchedulerEndpointRef.ask[Boolean](prepareRequestExecutors(requestedTotal)) + override def doRequestTotalExecutors( + resourceProfileToTotalExecs: Map[ResourceProfile, Int]): Future[Boolean] = { + yarnSchedulerEndpointRef.ask[Boolean](prepareRequestExecutors(resourceProfileToTotalExecs)) } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index a87820b1528ad..c0c6fff5130bb 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -24,6 +24,7 @@ import org.mockito.Mockito.when import org.scalatestplus.mockito.MockitoSugar import org.apache.spark._ +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.serializer.JavaSerializer import org.apache.spark.ui.TestFilter @@ -51,7 +52,8 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc private class TestYarnSchedulerBackend(scheduler: TaskSchedulerImpl, sc: SparkContext) extends YarnSchedulerBackend(scheduler, sc) { def setHostToLocalTaskCount(hostToLocalTaskCount: Map[String, Int]): Unit = { - this.hostToLocalTaskCount = hostToLocalTaskCount + this.rpHostToLocalTaskCount = Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> + hostToLocalTaskCount) } } @@ -72,7 +74,8 @@ class YarnSchedulerBackendSuite extends SparkFunSuite with MockitoSugar with Loc } { yarnSchedulerBackendExtended.setHostToLocalTaskCount(hostToLocalCount) sched.setNodeBlacklist(blacklist) - val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numRequested) + val numReq = Map(ResourceProfile.getOrCreateDefaultProfile(sc.getConf) -> numRequested) + val req = yarnSchedulerBackendExtended.prepareRequestExecutors(numReq) assert(req.requestedTotal === numRequested) assert(req.nodeBlacklist === blacklist) assert(req.hostToLocalTaskCount.keySet.intersect(blacklist).isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala index a88abc8209a88..c09ff51ecaff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala @@ -27,6 +27,7 @@ import org.apache.spark.{SparkConf, TaskState} import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config.Status._ +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetricInfo @@ -89,7 +90,8 @@ object MetricsAggregationBenchmark extends BenchmarkBase { val taskEventsTime = (0 until numStages).map { _ => val stageInfo = new StageInfo(idgen.incrementAndGet(), 0, getClass().getName(), - numTasks, Nil, Nil, getClass().getName()) + numTasks, Nil, Nil, getClass().getName(), + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) val jobId = idgen.incrementAndGet() val jobStart = SparkListenerJobStart( diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index 55b551d0af078..fdfd392a224cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config import org.apache.spark.internal.config.Status._ import org.apache.spark.rdd.RDD +import org.apache.spark.resource.ResourceProfile import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.InternalRow @@ -86,7 +87,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils name = "", rddInfos = Nil, parentIds = Nil, - details = "") + details = "", + resourceProfileId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID) } private def createTaskInfo( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala index e85a3b9009c32..58bd56c591d04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -23,6 +23,7 @@ import scala.util.Random import org.apache.spark.{ExecutorAllocationClient, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Streaming._ +import org.apache.spark.resource.ResourceProfile import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, Utils} @@ -111,7 +112,11 @@ private[streaming] class ExecutorAllocationManager( logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}") val targetTotalExecutors = math.max(math.min(maxNumExecutors, allExecIds.size + numNewExecutors), minNumExecutors) - client.requestTotalExecutors(targetTotalExecutors, 0, Map.empty) + // Just map the targetTotalExecutors to the default ResourceProfile + client.requestTotalExecutors( + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> targetTotalExecutors), + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 0), + Map.empty) logInfo(s"Requested total $targetTotalExecutors executors") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala index 9121da4b9b673..65efa10bfcf92 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -27,6 +27,7 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{ExecutorAllocationClient, SparkConf} import org.apache.spark.internal.config.{DYN_ALLOCATION_ENABLED, DYN_ALLOCATION_TESTING} import org.apache.spark.internal.config.Streaming._ +import org.apache.spark.resource.ResourceProfile import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext, TestSuiteBase} import org.apache.spark.util.{ManualClock, Utils} @@ -71,10 +72,15 @@ class ExecutorAllocationManagerSuite extends TestSuiteBase if (expectedRequestedTotalExecs.nonEmpty) { 
require(expectedRequestedTotalExecs.get > 0) verify(allocationClient, times(1)).requestTotalExecutors( - meq(expectedRequestedTotalExecs.get), meq(0), meq(Map.empty)) + meq(Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> + expectedRequestedTotalExecs.get)), + meq(Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 0)), + meq(Map.empty)) } else { - verify(allocationClient, never).requestTotalExecutors(0, 0, Map.empty) - } + verify(allocationClient, never).requestTotalExecutors( + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 0), + Map(ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID -> 0), + Map.empty)} } /** Verify that a particular executor was killed */ @@ -139,8 +145,11 @@ class ExecutorAllocationManagerSuite extends TestSuiteBase reset(allocationClient) when(allocationClient.getExecutorIds()).thenReturn((1 to numExecs).map(_.toString)) requestExecutors(allocationManager, numNewExecs) - verify(allocationClient, times(1)).requestTotalExecutors( - meq(expectedRequestedTotalExecs), meq(0), meq(Map.empty)) + val defaultProfId = ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID + verify(allocationClient, times(1)). + requestTotalExecutors( + meq(Map(defaultProfId -> expectedRequestedTotalExecs)), + meq(Map(defaultProfId -> 0)), meq(Map.empty)) } withAllocationManager(numReceivers = 1) { case (_, allocationManager) => From 926e3a1efe9e142804fcbf52146b22700640ae1b Mon Sep 17 00:00:00 2001 From: iRakson Date: Thu, 13 Feb 2020 12:23:40 +0800 Subject: [PATCH 030/185] [SPARK-30790] The dataType of map() should be map ### What changes were proposed in this pull request? `spark.sql("select map()")` returns {}. After these changes it will return map ### Why are the changes needed? After changes introduced due to #27521, it is important to maintain consistency while using map(). ### Does this PR introduce any user-facing change? Yes. Now map() will give map instead of {}. ### How was this patch tested? UT added. Migration guide updated as well Closes #27542 from iRakson/SPARK-30790. Authored-by: iRakson Signed-off-by: Wenchen Fan --- docs/sql-migration-guide.md | 2 +- .../expressions/complexTypeCreator.scala | 14 ++++++++--- .../catalyst/util/ArrayBasedMapBuilder.scala | 5 ++-- .../apache/spark/sql/internal/SQLConf.scala | 10 ++++---- .../spark/sql/DataFrameFunctionsSuite.scala | 25 +++++++++++++------ 5 files changed, 36 insertions(+), 20 deletions(-) diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index f98fab5b4c56b..46b741687363f 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -216,7 +216,7 @@ license: | - Since Spark 3.0, the `size` function returns `NULL` for the `NULL` input. In Spark version 2.4 and earlier, this function gives `-1` for the same input. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.sizeOfNull` to `true`. - - Since Spark 3.0, when the `array` function is called without any parameters, it returns an empty array of `NullType`. In Spark version 2.4 and earlier, it returns an empty array of string type. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.arrayDefaultToStringType.enabled` to `true`. + - Since Spark 3.0, when the `array`/`map` function is called without any parameters, it returns an empty collection with `NullType` as element type. In Spark version 2.4 and earlier, it returns an empty collection with `StringType` as element type. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.createEmptyCollectionUsingStringType` to `true`. 
- Since Spark 3.0, the interval literal syntax does not allow multiple from-to units anymore. For example, `SELECT INTERVAL '1-1' YEAR TO MONTH '2-2' YEAR TO MONTH'` throws parser exception. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 7335e305bfe55..4bd85d304ded2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -46,7 +46,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { } private val defaultElementType: DataType = { - if (SQLConf.get.getConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING)) { + if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) { StringType } else { NullType @@ -145,6 +145,14 @@ case class CreateMap(children: Seq[Expression]) extends Expression { lazy val keys = children.indices.filter(_ % 2 == 0).map(children) lazy val values = children.indices.filter(_ % 2 != 0).map(children) + private val defaultElementType: DataType = { + if (SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) { + StringType + } else { + NullType + } + } + override def foldable: Boolean = children.forall(_.foldable) override def checkInputDataTypes(): TypeCheckResult = { @@ -167,9 +175,9 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override lazy val dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) - .getOrElse(StringType), + .getOrElse(defaultElementType), valueType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(values.map(_.dataType)) - .getOrElse(StringType), + .getOrElse(defaultElementType), valueContainsNull = values.exists(_.nullable)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index 98934368205ec..37d65309e2b89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -29,12 +29,11 @@ import org.apache.spark.unsafe.array.ByteArrayMethods */ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable { assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map") - assert(keyType != NullType, "map key cannot be null type.") private lazy val keyToIndex = keyType match { // Binary type data is `byte[]`, which can't use `==` to check equality. - case _: AtomicType | _: CalendarIntervalType if !keyType.isInstanceOf[BinaryType] => - new java.util.HashMap[Any, Int]() + case _: AtomicType | _: CalendarIntervalType | _: NullType + if !keyType.isInstanceOf[BinaryType] => new java.util.HashMap[Any, Int]() case _ => // for complex types, use interpreted ordering to be able to compare unsafe data with safe // data, e.g. UnsafeRow vs GenericInternalRow. 
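As a quick, hypothetical illustration of the behavior change described in the commit message above (this sketch is not part of the patch; the config name comes from the SQLConf change below and the types follow the documented defaults):

```sql
-- Sketch only: with the Spark 3.0 default, an argument-less map() is typed with NullType
SELECT map();   -- dataType: map<null,null>; the value still prints as an empty map {}

-- Opting back into the 2.4 behavior via the new (internal) legacy flag
SET spark.sql.legacy.createEmptyCollectionUsingStringType=true;
SELECT map();   -- dataType: map<string,string>
```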
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d86f8693e0655..95b5b3afc3933 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2007,12 +2007,12 @@ object SQLConf { .booleanConf .createWithDefault(false) - val LEGACY_ARRAY_DEFAULT_TO_STRING = - buildConf("spark.sql.legacy.arrayDefaultToStringType.enabled") + val LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE = + buildConf("spark.sql.legacy.createEmptyCollectionUsingStringType") .internal() - .doc("When set to true, it returns an empty array of string type when the `array` " + - "function is called without any parameters. Otherwise, it returns an empty " + - "array of `NullType`") + .doc("When set to true, Spark returns an empty collection with `StringType` as element " + + "type if the `array`/`map` function is called without any parameters. Otherwise, Spark " + + "returns an empty collection with `NullType` as element type.") .booleanConf .createWithDefault(false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 6012678341ccc..f7531ea446015 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3499,13 +3499,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } - test("SPARK-21281 use string types by default if map have no argument") { - val ds = spark.range(1) - var expectedSchema = new StructType() - .add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false) - assert(ds.select(map().as("x")).schema == expectedSchema) - } - test("SPARK-21281 fails if functions have no argument") { val df = Seq(1).toDF("a") @@ -3563,7 +3556,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-29462: Empty array of NullType for array function with no arguments") { Seq((true, StringType), (false, NullType)).foreach { case (arrayDefaultToString, expectedType) => - withSQLConf(SQLConf.LEGACY_ARRAY_DEFAULT_TO_STRING.key -> arrayDefaultToString.toString) { + withSQLConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE.key -> + arrayDefaultToString.toString) { val schema = spark.range(1).select(array()).schema assert(schema.nonEmpty && schema.head.dataType.isInstanceOf[ArrayType]) val actualType = schema.head.dataType.asInstanceOf[ArrayType].elementType @@ -3571,6 +3565,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } } } + + test("SPARK-30790: Empty map with NullType as key/value type for map function with no argument") { + Seq((true, StringType), (false, NullType)).foreach { + case (mapDefaultToString, expectedType) => + withSQLConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE.key -> + mapDefaultToString.toString) { + val schema = spark.range(1).select(map()).schema + assert(schema.nonEmpty && schema.head.dataType.isInstanceOf[MapType]) + val actualKeyType = schema.head.dataType.asInstanceOf[MapType].keyType + val actualValueType = schema.head.dataType.asInstanceOf[MapType].valueType + assert(actualKeyType === expectedType) + assert(actualValueType === expectedType) + } + } + } } object 
DataFrameFunctionsSuite { From 453d5261b22ebcdd5886e65ab9d0d9857051e76a Mon Sep 17 00:00:00 2001 From: maryannxue Date: Thu, 13 Feb 2020 19:32:38 +0800 Subject: [PATCH 031/185] [SPARK-30528][SQL] Turn off DPP subquery duplication by default ### What changes were proposed in this pull request? This PR adds a config for Dynamic Partition Pruning subquery duplication and turns it off by default due to its potential performance regression. When planning a DPP filter, it seeks to reuse the broadcast exchange relation if the corresponding join is a BHJ with the filter relation being on the build side, otherwise it will either opt out or plan the filter as an un-reusable subquery duplication based on the cost estimate. However, the cost estimate is not accurate and only takes into account the table scan overhead, thus adding an un-reusable subquery duplication DPP filter can sometimes cause perf regression. This PR turns off the subquery duplication DPP filter by: 1. adding a config `spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly` and setting it `true` by default. 2. removing the existing meaningless config `spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcast` since we always want to reuse broadcast results if possible. ### Why are the changes needed? This is to fix a potential performance regression caused by DPP. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Updated DynamicPartitionPruningSuite to test the new configuration. Closes #27551 from maryannxue/spark-30528. Authored-by: maryannxue Signed-off-by: Wenchen Fan --- .../apache/spark/sql/internal/SQLConf.scala | 12 +- .../sql/dynamicpruning/PartitionPruning.scala | 4 +- .../PlanDynamicPruningFilters.scala | 5 +- .../sql/DynamicPartitionPruningSuite.scala | 183 +++++++----------- .../org/apache/spark/sql/ExplainSuite.scala | 3 +- 5 files changed, 82 insertions(+), 125 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 95b5b3afc3933..2214e03f34f0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -259,11 +259,11 @@ object SQLConf { .doubleConf .createWithDefault(0.5) - val DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST = - buildConf("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcast") + val DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY = + buildConf("spark.sql.optimizer.dynamicPartitionPruning.reuseBroadcastOnly") .internal() - .doc("When true, dynamic partition pruning will seek to reuse the broadcast results from " + - "a broadcast hash join operation.") + .doc("When true, dynamic partition pruning will only apply when the broadcast exchange of " + + "a broadcast hash join operation can be reused as the dynamic pruning filter.") .booleanConf .createWithDefault(true) @@ -2303,8 +2303,8 @@ class SQLConf extends Serializable with Logging { def dynamicPartitionPruningFallbackFilterRatio: Double = getConf(DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO) - def dynamicPartitionPruningReuseBroadcast: Boolean = - getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST) + def dynamicPartitionPruningReuseBroadcastOnly: Boolean = + getConf(DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY) def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala index 48ba8618f272e..28f8f49d2ce44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PartitionPruning.scala @@ -86,7 +86,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { filteringPlan: LogicalPlan, joinKeys: Seq[Expression], hasBenefit: Boolean): LogicalPlan = { - val reuseEnabled = SQLConf.get.dynamicPartitionPruningReuseBroadcast + val reuseEnabled = SQLConf.get.exchangeReuseEnabled val index = joinKeys.indexOf(filteringKey) if (hasBenefit || reuseEnabled) { // insert a DynamicPruning wrapper to identify the subquery during query planning @@ -96,7 +96,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper { filteringPlan, joinKeys, index, - !hasBenefit), + !hasBenefit || SQLConf.get.dynamicPartitionPruningReuseBroadcastOnly), pruningPlan) } else { // abort dynamic partition pruning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala index 1398dc049dd99..be00f728aa3ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/dynamicpruning/PlanDynamicPruningFilters.scala @@ -36,9 +36,6 @@ import org.apache.spark.sql.internal.SQLConf case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[SparkPlan] with PredicateHelper { - private def reuseBroadcast: Boolean = - SQLConf.get.dynamicPartitionPruningReuseBroadcast && SQLConf.get.exchangeReuseEnabled - /** * Identify the shape in which keys of a given plan are broadcasted. */ @@ -59,7 +56,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) sparkSession, sparkSession.sessionState.planner, buildPlan) // Using `sparkPlan` is a little hacky as it is based on the assumption that this rule is // the first to be applied (apart from `InsertAdaptiveSparkPlan`). 
- val canReuseExchange = reuseBroadcast && buildKeys.nonEmpty && + val canReuseExchange = SQLConf.get.exchangeReuseEnabled && buildKeys.nonEmpty && plan.find { case BroadcastHashJoinExec(_, _, _, BuildLeft, _, left, _) => left.sameResult(sparkPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index e1f9bcc4e008d..f7b51d6f4c8ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -239,7 +239,8 @@ class DynamicPartitionPruningSuite */ test("simple inner join triggers DPP with mock-up tables") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("df1", "df2") { spark.range(1000) .select(col("id"), col("id").as("k")) @@ -271,7 +272,8 @@ class DynamicPartitionPruningSuite */ test("self-join on a partitioned table should not trigger DPP") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("fact") { sql( s""" @@ -302,7 +304,8 @@ class DynamicPartitionPruningSuite */ test("static scan metrics") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("fact", "dim") { spark.range(10) .map { x => Tuple3(x, x + 1, 0) } @@ -370,7 +373,8 @@ class DynamicPartitionPruningSuite test("DPP should not be rewritten as an existential join") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "1.5", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( s""" |SELECT * FROM product p WHERE p.store_id NOT IN @@ -395,7 +399,7 @@ class DynamicPartitionPruningSuite */ test("DPP triggers only for certain types of query") { withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false") { Given("dynamic partition pruning disabled") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { val df = sql( @@ -433,7 +437,8 @@ class DynamicPartitionPruningSuite } Given("left-semi join with partition column on the left side") - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT * FROM fact_sk f @@ -457,7 +462,8 @@ class DynamicPartitionPruningSuite } Given("right outer join with partition column on the left side") - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( 
""" |SELECT * FROM fact_sk f RIGHT OUTER JOIN dim_store s @@ -474,7 +480,8 @@ class DynamicPartitionPruningSuite */ test("filtering ratio policy fallback") { withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { Given("no stats and selective predicate") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { @@ -543,7 +550,8 @@ class DynamicPartitionPruningSuite */ test("filtering ratio policy with stats when the broadcast pruning is disabled") { withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { Given("disabling the use of stats in the DPP heuristic") withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false") { @@ -613,10 +621,7 @@ class DynamicPartitionPruningSuite test("partition pruning in broadcast hash joins with non-deterministic probe part") { Given("alias with simple join condition, and non-deterministic query") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -630,10 +635,7 @@ class DynamicPartitionPruningSuite } Given("alias over multiple sub-queries with simple join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -651,10 +653,7 @@ class DynamicPartitionPruningSuite test("partition pruning in broadcast hash joins with aliases") { Given("alias with simple join condition, using attribute names only") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -674,10 +673,7 @@ class DynamicPartitionPruningSuite } Given("alias with expr as join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -697,10 +693,7 @@ class DynamicPartitionPruningSuite } Given("alias over multiple sub-queries with simple join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> 
"true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -722,10 +715,7 @@ class DynamicPartitionPruningSuite } Given("alias over multiple sub-queries with simple join condition") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid_d as pid, f.sid_d as sid FROM @@ -754,10 +744,8 @@ class DynamicPartitionPruningSuite test("partition pruning in broadcast hash joins") { Given("disable broadcast pruning and disable subquery duplication") withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f @@ -777,9 +765,10 @@ class DynamicPartitionPruningSuite Given("disable reuse broadcast results and enable subquery duplication") withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0.5") { + SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0.5", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { val df = sql( """ |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f @@ -798,52 +787,47 @@ class DynamicPartitionPruningSuite } Given("enable reuse broadcast results and disable query duplication") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { - val df = sql( - """ - |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f - |JOIN dim_stats s - |ON f.store_id = s.store_id WHERE s.country = 'DE' - """.stripMargin) + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) - checkPartitionPruningPredicate(df, false, true) + checkPartitionPruningPredicate(df, false, true) - checkAnswer(df, - Row(1030, 2, 10, 3) :: - Row(1040, 2, 50, 3) :: - Row(1050, 2, 50, 3) :: - Row(1060, 2, 50, 3) :: Nil - ) + checkAnswer(df, + Row(1030, 2, 10, 3) :: + Row(1040, 2, 50, 3) :: + Row(1050, 2, 50, 3) :: + Row(1060, 2, 50, 3) :: Nil + ) } Given("disable broadcast hash join and disable query duplication") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { - val df = sql( - """ - |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f - |JOIN dim_stats s - |ON f.store_id = s.store_id WHERE s.country = 'DE' - """.stripMargin) 
+ withSQLConf( + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql( + """ + |SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f + |JOIN dim_stats s + |ON f.store_id = s.store_id WHERE s.country = 'DE' + """.stripMargin) - checkPartitionPruningPredicate(df, false, false) + checkPartitionPruningPredicate(df, false, false) - checkAnswer(df, - Row(1030, 2, 10, 3) :: - Row(1040, 2, 50, 3) :: - Row(1050, 2, 50, 3) :: - Row(1060, 2, 50, 3) :: Nil - ) + checkAnswer(df, + Row(1030, 2, 10, 3) :: + Row(1040, 2, 50, 3) :: + Row(1050, 2, 50, 3) :: + Row(1060, 2, 50, 3) :: Nil + ) } Given("disable broadcast hash join and enable query duplication") - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "true") { val df = sql( @@ -865,9 +849,7 @@ class DynamicPartitionPruningSuite } test("broadcast a single key in a HashedRelation") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -925,9 +907,7 @@ class DynamicPartitionPruningSuite } test("broadcast multiple keys in a LongHashedRelation") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -962,9 +942,7 @@ class DynamicPartitionPruningSuite } test("broadcast multiple keys in an UnsafeHashedRelation") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -999,9 +977,7 @@ class DynamicPartitionPruningSuite } test("different broadcast subqueries with identical children") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(100).select( $"id", @@ -1073,7 +1049,7 @@ class DynamicPartitionPruningSuite } test("avoid reordering broadcast join keys to match input hash partitioning") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTable("large", "dimTwo", "dimThree") { spark.range(100).select( @@ -1123,9 +1099,7 @@ class DynamicPartitionPruningSuite * duplicated partitioning keys, also used to uniquely identify the dynamic pruning filters. 
*/ test("dynamic partition pruning ambiguity issue across nested joins") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("store", "date", "item") { spark.range(500) .select((($"id" + 30) % 50).as("ss_item_sk"), @@ -1163,9 +1137,7 @@ class DynamicPartitionPruningSuite } test("cleanup any DPP filter that isn't pushed down due to expression id clashes") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { withTable("fact", "dim") { spark.range(1000).select($"id".as("A"), $"id".as("AA")) .write.partitionBy("A").format(tableFormat).mode("overwrite").saveAsTable("fact") @@ -1186,10 +1158,7 @@ class DynamicPartitionPruningSuite } test("cleanup any DPP filter that isn't pushed down due to non-determinism") { - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.pid, f.sid FROM @@ -1204,9 +1173,7 @@ class DynamicPartitionPruningSuite } test("join key with multiple references on the filtering plan") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0", + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { // when enable AQE, the reusedExchange is inserted when executed. 
withTable("fact", "dim") { @@ -1240,9 +1207,7 @@ class DynamicPartitionPruningSuite } test("Make sure dynamic pruning works on uncorrelated queries") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT d.store_id, @@ -1266,10 +1231,7 @@ class DynamicPartitionPruningSuite test("Plan broadcast pruning only when the broadcast can be reused") { Given("dynamic pruning filter on the build side") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT f.date_id, f.store_id, f.product_id, f.units_sold FROM fact_np f @@ -1288,10 +1250,7 @@ class DynamicPartitionPruningSuite } Given("dynamic pruning filter on the probe side") - withSQLConf( - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_USE_STATS.key -> "false", - SQLConf.DYNAMIC_PARTITION_PRUNING_FALLBACK_FILTER_RATIO.key -> "0") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { val df = sql( """ |SELECT /*+ BROADCAST(f)*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index d9f4d6d5132ae..b591705274110 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -239,7 +239,8 @@ class ExplainSuite extends QueryTest with SharedSparkSession { test("explain formatted - check presence of subquery in case of DPP") { withTable("df1", "df2") { withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", - SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST.key -> "false") { + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { withTable("df1", "df2") { spark.range(1000).select(col("id"), col("id").as("k")) .write From a6b4b914f2d2b873b0e9b9d446fda69dc74c3cf8 Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 13 Feb 2020 20:13:36 +0800 Subject: [PATCH 032/185] [SPARK-30613][SQL] Support Hive style REPLACE COLUMNS syntax ### What changes were proposed in this pull request? This PR proposes to support Hive-style `ALTER TABLE ... REPLACE COLUMNS ...` as described in https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-Add/ReplaceColumns The user now can do the following: ```SQL CREATE TABLE t (col1 int, col2 int) USING Foo; ALTER TABLE t REPLACE COLUMNS (col2 string COMMENT 'comment2', col3 int COMMENT 'comment3'); ``` , which drops the existing columns `col1` and `col2`, and add new columns `col2` and `col3`. ### Why are the changes needed? This is a new DDL statement. Spark currently supports the Hive-style `ALTER TABLE ... CHANGE COLUMN ...`, so this new addition can be useful. ### Does this PR introduce any user-facing change? Yes, adding a new DDL statement. ### How was this patch tested? More tests to be added. Closes #27482 from imback82/replace_cols. 
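Alongside the happy path shown above, the new parser and analyzer checks reject a few Hive-style variants. A hedged sketch of statements the added tests expect to fail (the table name `t` is just a placeholder):

```sql
-- Sketch of the rejected forms, based on the error messages added in this patch
ALTER TABLE t REPLACE COLUMNS (col2 string NOT NULL);       -- NOT NULL is not supported in Hive-style REPLACE COLUMNS
ALTER TABLE t REPLACE COLUMNS (col2 string FIRST);          -- column position is not supported in Hive-style REPLACE COLUMNS
ALTER TABLE t PARTITION (a = '1') REPLACE COLUMNS (col2 string);  -- PARTITION spec is not allowed with REPLACE COLUMNS
-- For v1 (session catalog) tables, the analyzer fails with:
--   "REPLACE COLUMNS is only supported with v2 tables."
```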
Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 3 ++ .../sql/catalyst/analysis/CheckAnalysis.scala | 12 ++++- .../catalyst/analysis/ResolveCatalogs.scala | 21 ++++++++ .../sql/catalyst/parser/AstBuilder.scala | 21 ++++++++ .../catalyst/plans/logical/statements.scala | 4 ++ .../sql/catalyst/parser/DDLParserSuite.scala | 48 ++++++++++++++++++- .../analysis/ResolveSessionCatalog.scala | 23 +++++++++ .../spark/sql/connector/AlterTableTests.scala | 15 ++++++ 8 files changed, 145 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 563ef69b3b8ae..2bc71476aba02 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -164,6 +164,9 @@ statement | ALTER TABLE table=multipartIdentifier partitionSpec? CHANGE COLUMN? colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + REPLACE COLUMNS + '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns | ALTER TABLE multipartIdentifier (partitionSpec)? SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe | ALTER TABLE multipartIdentifier (partitionSpec)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index e769e038c960f..67c509ed98245 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -473,9 +473,15 @@ trait CheckAnalysis extends PredicateHelper { } } + val colsToDelete = mutable.Set.empty[Seq[String]] + alter.changes.foreach { case add: AddColumn => - checkColumnNotExists("add", add.fieldNames(), table.schema) + // If a column to add is a part of columns to delete, we don't need to check + // if column already exists - applies to REPLACE COLUMNS scenario. + if (!colsToDelete.contains(add.fieldNames())) { + checkColumnNotExists("add", add.fieldNames(), table.schema) + } val parent = findParentStruct("add", add.fieldNames()) positionArgumentExists(add.position(), parent) TypeUtils.failWithIntervalType(add.dataType()) @@ -526,6 +532,10 @@ trait CheckAnalysis extends PredicateHelper { findField("update", update.fieldNames) case delete: DeleteColumn => findField("delete", delete.fieldNames) + // REPLACE COLUMNS has deletes followed by adds. Remember the deleted columns + // so that add operations do not fail when the columns to add exist and they + // are to be deleted. 
+ colsToDelete += delete.fieldNames case _ => // no validation needed for set and remove property } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index 88a3c0a73a10b..96558410d4004 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -44,6 +44,27 @@ class ResolveCatalogs(val catalogManager: CatalogManager) } createAlterTable(nameParts, catalog, tbl, changes) + case AlterTableReplaceColumnsStatement( + nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => + val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { + case Some(table) => + // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. + val deleteChanges = table.schema.fieldNames.map { name => + TableChange.deleteColumn(Array(name)) + } + val addChanges = cols.map { col => + TableChange.addColumn( + col.name.toArray, + col.dataType, + col.nullable, + col.comment.orNull, + col.position.orNull) + } + deleteChanges ++ addChanges + case None => Seq() + } + createAlterTable(nameParts, catalog, tbl, changes) + case a @ AlterTableAlterColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => val colName = a.column.toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 62e568587fcc6..b3541a7f7374d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3030,6 +3030,27 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging position = Option(ctx.colPosition).map(typedVisit[ColumnPosition])) } + override def visitHiveReplaceColumns( + ctx: HiveReplaceColumnsContext): LogicalPlan = withOrigin(ctx) { + if (ctx.partitionSpec != null) { + operationNotAllowed("ALTER TABLE table PARTITION partition_spec REPLACE COLUMNS", ctx) + } + AlterTableReplaceColumnsStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + ctx.columns.qualifiedColTypeWithPosition.asScala.map { colType => + if (colType.NULL != null) { + throw new AnalysisException( + "NOT NULL is not supported in Hive-style REPLACE COLUMNS") + } + if (colType.colPosition != null) { + throw new AnalysisException( + "Column position is not supported in Hive-style REPLACE COLUMNS") + } + typedVisit[QualifiedColType](colType) + } + ) + } + /** * Parse a [[AlterTableDropColumnsStatement]] command. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index 1e6b67bf78b70..6731214d3842d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -156,6 +156,10 @@ case class AlterTableAddColumnsStatement( tableName: Seq[String], columnsToAdd: Seq[QualifiedColType]) extends ParsedStatement +case class AlterTableReplaceColumnsStatement( + tableName: Seq[String], + columnsToAdd: Seq[QualifiedColType]) extends ParsedStatement + /** * ALTER TABLE ... 
CHANGE COLUMN command, as parsed from SQL. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index bc7b51f25b20d..049f56c8c9ce1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -699,7 +699,7 @@ class DDLParserSuite extends AnalysisTest { } } - test("alter table: hive style") { + test("alter table: hive style change column") { val sql1 = "ALTER TABLE table_name CHANGE COLUMN a.b.c c INT" val sql2 = "ALTER TABLE table_name CHANGE COLUMN a.b.c c INT COMMENT 'new_comment'" val sql3 = "ALTER TABLE table_name CHANGE COLUMN a.b.c c INT AFTER other_col" @@ -742,6 +742,52 @@ class DDLParserSuite extends AnalysisTest { intercept("ALTER TABLE table_name PARTITION (a='1') CHANGE COLUMN a.b.c c INT") } + test("alter table: hive style replace columns") { + val sql1 = "ALTER TABLE table_name REPLACE COLUMNS (x string)" + val sql2 = "ALTER TABLE table_name REPLACE COLUMNS (x string COMMENT 'x1')" + val sql3 = "ALTER TABLE table_name REPLACE COLUMNS (x string COMMENT 'x1', y int)" + val sql4 = "ALTER TABLE table_name REPLACE COLUMNS (x string COMMENT 'x1', y int COMMENT 'y1')" + + comparePlans( + parsePlan(sql1), + AlterTableReplaceColumnsStatement( + Seq("table_name"), + Seq(QualifiedColType(Seq("x"), StringType, true, None, None)))) + + comparePlans( + parsePlan(sql2), + AlterTableReplaceColumnsStatement( + Seq("table_name"), + Seq(QualifiedColType(Seq("x"), StringType, true, Some("x1"), None)))) + + comparePlans( + parsePlan(sql3), + AlterTableReplaceColumnsStatement( + Seq("table_name"), + Seq( + QualifiedColType(Seq("x"), StringType, true, Some("x1"), None), + QualifiedColType(Seq("y"), IntegerType, true, None, None) + ))) + + comparePlans( + parsePlan(sql4), + AlterTableReplaceColumnsStatement( + Seq("table_name"), + Seq( + QualifiedColType(Seq("x"), StringType, true, Some("x1"), None), + QualifiedColType(Seq("y"), IntegerType, true, Some("y1"), None) + ))) + + intercept("ALTER TABLE table_name PARTITION (a='1') REPLACE COLUMNS (x string)", + "Operation not allowed: ALTER TABLE table PARTITION partition_spec REPLACE COLUMNS") + + intercept("ALTER TABLE table_name REPLACE COLUMNS (x string NOT NULL)", + "NOT NULL is not supported in Hive-style REPLACE COLUMNS") + + intercept("ALTER TABLE table_name REPLACE COLUMNS (x string FIRST)", + "Column position is not supported in Hive-style REPLACE COLUMNS") + } + test("alter table/view: rename table/view") { comparePlans( parsePlan("ALTER TABLE a.b.c RENAME TO x.y.z"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 77d549c28aae5..adeb2164eff63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -69,6 +69,29 @@ class ResolveSessionCatalog( createAlterTable(nameParts, catalog, tbl, changes) } + case AlterTableReplaceColumnsStatement( + nameParts @ SessionCatalogAndTable(catalog, tbl), cols) => + val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { + case Some(_: V1Table) => + throw new AnalysisException("REPLACE COLUMNS is only supported with v2 tables.") 
+ case Some(table) => + // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. + val deleteChanges = table.schema.fieldNames.map { name => + TableChange.deleteColumn(Array(name)) + } + val addChanges = cols.map { col => + TableChange.addColumn( + col.name.toArray, + col.dataType, + col.nullable, + col.comment.orNull, + col.position.orNull) + } + deleteChanges ++ addChanges + case None => Seq() // Unresolved table will be handled in CheckAnalysis. + } + createAlterTable(nameParts, catalog, tbl, changes) + case a @ AlterTableAlterColumnStatement( nameParts @ SessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => loadTable(catalog, tbl.asIdentifier).collect { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 420cb01d766a0..96fe301b512ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -1060,4 +1060,19 @@ trait AlterTableTests extends SharedSparkSession { assert(updated.properties === withDefaultOwnership(Map("provider" -> v2Format)).asJava) } } + + test("AlterTable: replace columns") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql(s"CREATE TABLE $t (col1 int, col2 int COMMENT 'c2') USING $v2Format") + sql(s"ALTER TABLE $t REPLACE COLUMNS (col2 string, col3 int COMMENT 'c3')") + + val table = getTableMetadata(t) + + assert(table.name === fullTableName(t)) + assert(table.schema === StructType(Seq( + StructField("col2", StringType), + StructField("col3", IntegerType).withComment("c3")))) + } + } } From 04604b9899cc43a9726d671061ff305912fdb85f Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 13 Feb 2020 22:06:24 +0800 Subject: [PATCH 033/185] [SPARK-30758][SQL][TESTS] Improve bracketed comments tests ### What changes were proposed in this pull request? Although Spark SQL support bracketed comments, but `SQLQueryTestSuite` can't treat bracketed comments well and lead to generated golden files can't display bracketed comments well. This PR will improve the treatment of bracketed comments and add three test case in `PlanParserSuite`. Spark SQL can't support nested bracketed comments and https://github.com/apache/spark/pull/27495 used to support it. ### Why are the changes needed? Golden files can't display well. ### Does this PR introduce any user-facing change? No ### How was this patch tested? New UT. Closes #27481 from beliefer/ansi-brancket-comments. 
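For context, a short sketch of the new `--QUERY-DELIMITER-START`/`--QUERY-DELIMITER-END` convention this patch adds to SQLQueryTestSuite input files (adapted from the comments.sql changes below, not an exact excerpt):

```sql
--QUERY-DELIMITER-START
-- everything between the two markers is submitted as a single query, so the semicolons
-- inside the bracketed comment no longer split it into broken fragments
/* This is an example of SQL which should not execute:
 * select 'multi-line';
 */
SELECT 'after multi-line' AS fifth;
--QUERY-DELIMITER-END
```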
Authored-by: beliefer Signed-off-by: Wenchen Fan --- .../sql-tests/inputs/postgreSQL/comments.sql | 6 +- .../results/postgreSQL/comments.sql.out | 137 ++++-------------- .../apache/spark/sql/SQLQueryTestSuite.scala | 51 ++++++- 3 files changed, 78 insertions(+), 116 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql index 6725ce45e72a5..1a454179ef79f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/postgreSQL/comments.sql @@ -11,17 +11,19 @@ SELECT /* embedded single line */ 'embedded' AS `second`; SELECT /* both embedded and trailing single line */ 'both' AS third; -- trailing single line SELECT 'before multi-line' AS fourth; +--QUERY-DELIMITER-START -- [SPARK-28880] ANSI SQL: Bracketed comments /* This is an example of SQL which should not execute: * select 'multi-line'; */ SELECT 'after multi-line' AS fifth; +--QUERY-DELIMITER-END -- [SPARK-28880] ANSI SQL: Bracketed comments -- -- Nested comments -- - +--QUERY-DELIMITER-START /* SELECT 'trailing' as x1; -- inside block comment */ @@ -44,5 +46,5 @@ Hoo boy. Still two deep... Now just one deep... */ 'deeply nested example' AS sixth; - +--QUERY-DELIMITER-END /* and this is the end of the file */ diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out index 4ea49013a62d1..637c5561bd940 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/comments.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 7 -- !query @@ -36,129 +36,32 @@ before multi-line -- !query /* This is an example of SQL which should not execute: - * select 'multi-line' --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* This is an example of SQL which should not execute: -^^^ - * select 'multi-line' - - --- !query -*/ + * select 'multi-line'; + */ SELECT 'after multi-line' AS fifth -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -extraneous input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -*/ -^^^ -SELECT 'after multi-line' AS fifth +after multi-line -- !query /* -SELECT 'trailing' as x1 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 
'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* -^^^ -SELECT 'trailing' as x1 - - --- !query +SELECT 'trailing' as x1; -- inside block comment */ /* This block comment surrounds a query which itself has a block comment... -SELECT /* embedded single line */ 'embedded' AS x2 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -*/ -^^^ - -/* This block comment surrounds a query which itself has a block comment... -SELECT /* embedded single line */ 'embedded' AS x2 - - --- !query +SELECT /* embedded single line */ 'embedded' AS x2; */ SELECT -- continued after the following block comments... /* Deeply nested comment. This includes a single apostrophe to make sure we aren't decoding this part as a string. -SELECT 'deep nest' AS n1 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -extraneous input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -*/ -^^^ - -SELECT -- continued after the following block comments... -/* Deeply nested comment. - This includes a single apostrophe to make sure we aren't decoding this part as a string. -SELECT 'deep nest' AS n1 - - --- !query +SELECT 'deep nest' AS n1; /* Second level of nesting... -SELECT 'deeper nest' as n2 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* Second level of nesting... -^^^ -SELECT 'deeper nest' as n2 - - --- !query +SELECT 'deeper nest' as n2; /* Third level of nesting... 
-SELECT 'deepest nest' as n3 --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException - -mismatched input '/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) - -== SQL == -/* Third level of nesting... -^^^ -SELECT 'deepest nest' as n3 - - --- !query +SELECT 'deepest nest' as n3; */ Hoo boy. Still two deep... */ @@ -170,11 +73,27 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException -mismatched input '*/' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 1, pos 0) +mismatched input ''embedded'' expecting {'(', 'ADD', 'ALTER', 'ANALYZE', 'CACHE', 'CLEAR', 'COMMENT', 'COMMIT', 'CREATE', 'DELETE', 'DESC', 'DESCRIBE', 'DFS', 'DROP', 'EXPLAIN', 'EXPORT', 'FROM', 'GRANT', 'IMPORT', 'INSERT', 'LIST', 'LOAD', 'LOCK', 'MAP', 'MERGE', 'MSCK', 'REDUCE', 'REFRESH', 'REPLACE', 'RESET', 'REVOKE', 'ROLLBACK', 'SELECT', 'SET', 'SHOW', 'START', 'TABLE', 'TRUNCATE', 'UNCACHE', 'UNLOCK', 'UPDATE', 'USE', 'VALUES', 'WITH'}(line 6, pos 34) == SQL == +/* +SELECT 'trailing' as x1; -- inside block comment +*/ + +/* This block comment surrounds a query which itself has a block comment... +SELECT /* embedded single line */ 'embedded' AS x2; +----------------------------------^^^ +*/ + +SELECT -- continued after the following block comments... +/* Deeply nested comment. + This includes a single apostrophe to make sure we aren't decoding this part as a string. +SELECT 'deep nest' AS n1; +/* Second level of nesting... +SELECT 'deeper nest' as n2; +/* Third level of nesting... +SELECT 'deepest nest' as n3; */ -^^^ Hoo boy. Still two deep... */ Now just one deep... diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 6b9e5bbd3c961..da4727f6a98cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -19,7 +19,9 @@ package org.apache.spark.sql import java.io.File import java.util.{Locale, TimeZone} +import java.util.regex.Pattern +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal import org.apache.spark.{SparkConf, SparkException} @@ -62,7 +64,12 @@ import org.apache.spark.tags.ExtendedSQLTest * }}} * * The format for input files is simple: - * 1. A list of SQL queries separated by semicolon. + * 1. A list of SQL queries separated by semicolons by default. If the semicolon cannot effectively + * separate the SQL queries in the test file(e.g. bracketed comments), please use + * --QUERY-DELIMITER-START and --QUERY-DELIMITER-END. Lines starting with + * --QUERY-DELIMITER-START and --QUERY-DELIMITER-END represent the beginning and end of a query, + * respectively. 
Code that is not surrounded by lines that begin with --QUERY-DELIMITER-START + * and --QUERY-DELIMITER-END is still separated by semicolons. * 2. Lines starting with -- are treated as comments and ignored. * 3. Lines starting with --SET are used to specify the configs when running this testing file. You * can set multiple configs in one --SET, using comma to separate them. Or you can use multiple @@ -246,9 +253,15 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { /** Run a test case. */ protected def runTest(testCase: TestCase): Unit = { + def splitWithSemicolon(seq: Seq[String]) = { + seq.mkString("\n").split("(?<=[^\\\\]);") + } val input = fileToString(new File(testCase.inputFile)) - val (comments, code) = input.split("\n").partition(_.trim.startsWith("--")) + val (comments, code) = input.split("\n").partition { line => + val newLine = line.trim + newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER") + } // If `--IMPORT` found, load code from another test case file, then insert them // into the head in this test. @@ -261,10 +274,38 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { } }.flatten + val allCode = importedCode ++ code + val tempQueries = if (allCode.exists(_.trim.startsWith("--QUERY-DELIMITER"))) { + // Although the loop is heavy, only used for bracketed comments test. + val querys = new ArrayBuffer[String] + val otherCodes = new ArrayBuffer[String] + var tempStr = "" + var start = false + for (c <- allCode) { + if (c.trim.startsWith("--QUERY-DELIMITER-START")) { + start = true + querys ++= splitWithSemicolon(otherCodes.toSeq) + otherCodes.clear() + } else if (c.trim.startsWith("--QUERY-DELIMITER-END")) { + start = false + querys += s"\n${tempStr.stripSuffix(";")}" + tempStr = "" + } else if (start) { + tempStr += s"\n$c" + } else { + otherCodes += c + } + } + if (otherCodes.nonEmpty) { + querys ++= splitWithSemicolon(otherCodes.toSeq) + } + querys.toSeq + } else { + splitWithSemicolon(allCode).toSeq + } + // List of SQL queries to run - // note: this is not a robust way to split queries using semicolon, but works for now. - val queries = (importedCode ++ code).mkString("\n").split("(?<=[^\\\\]);") - .map(_.trim).filter(_ != "").toSeq + val queries = tempQueries.map(_.trim).filter(_ != "").toSeq // Fix misplacement when comment is at the end of the query. .map(_.split("\n").filterNot(_.startsWith("--")).mkString("\n")).map(_.trim).filter(_ != "") From fb0e07b08ccaeda50a5121bcb1fab69a1ff749c4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Thu, 13 Feb 2020 22:48:27 +0800 Subject: [PATCH 034/185] [SPARK-29231][SQL] Constraints should be inferred from cast equality constraint ### What changes were proposed in this pull request? This PR add support infer constraints from cast equality constraint. 
For example: ```scala scala> spark.sql("create table spark_29231_1(c1 bigint, c2 bigint)") res0: org.apache.spark.sql.DataFrame = [] scala> spark.sql("create table spark_29231_2(c1 int, c2 bigint)") res1: org.apache.spark.sql.DataFrame = [] scala> spark.sql("select t1.* from spark_29231_1 t1 join spark_29231_2 t2 on (t1.c1 = t2.c1 and t1.c1 = 1)").explain == Physical Plan == *(2) Project [c1#5L, c2#6L] +- *(2) BroadcastHashJoin [c1#5L], [cast(c1#7 as bigint)], Inner, BuildRight :- *(2) Project [c1#5L, c2#6L] : +- *(2) Filter (isnotnull(c1#5L) AND (c1#5L = 1)) : +- *(2) ColumnarToRow : +- FileScan parquet default.spark_29231_1[c1#5L,c2#6L] Batched: true, DataFilters: [isnotnull(c1#5L), (c1#5L = 1)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-preview2-bin-hadoop2.7/spark-warehouse/spark_29231_1], PartitionFilters: [], PushedFilters: [IsNotNull(c1), EqualTo(c1,1)], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#209] +- *(1) Project [c1#7] +- *(1) Filter isnotnull(c1#7) +- *(1) ColumnarToRow +- FileScan parquet default.spark_29231_2[c1#7] Batched: true, DataFilters: [isnotnull(c1#7)], Format: Parquet, Location: InMemoryFileIndex[file:/root/spark-3.0.0-preview2-bin-hadoop2.7/spark-warehouse/spark_29231_2], PartitionFilters: [], PushedFilters: [IsNotNull(c1)], ReadSchema: struct ``` After this PR: ```scala scala> spark.sql("select t1.* from spark_29231_1 t1 join spark_29231_2 t2 on (t1.c1 = t2.c1 and t1.c1 = 1)").explain == Physical Plan == *(2) Project [c1#0L, c2#1L] +- *(2) BroadcastHashJoin [c1#0L], [cast(c1#2 as bigint)], Inner, BuildRight :- *(2) Project [c1#0L, c2#1L] : +- *(2) Filter (isnotnull(c1#0L) AND (c1#0L = 1)) : +- *(2) ColumnarToRow : +- FileScan parquet default.spark_29231_1[c1#0L,c2#1L] Batched: true, DataFilters: [isnotnull(c1#0L), (c1#0L = 1)], Format: Parquet, Location: InMemoryFileIndex[file:/root/opensource/spark/spark-warehouse/spark_29231_1], PartitionFilters: [], PushedFilters: [IsNotNull(c1), EqualTo(c1,1)], ReadSchema: struct +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))), [id=#99] +- *(1) Project [c1#2] +- *(1) Filter ((cast(c1#2 as bigint) = 1) AND isnotnull(c1#2)) +- *(1) ColumnarToRow +- FileScan parquet default.spark_29231_2[c1#2] Batched: true, DataFilters: [(cast(c1#2 as bigint) = 1), isnotnull(c1#2)], Format: Parquet, Location: InMemoryFileIndex[file:/root/opensource/spark/spark-warehouse/spark_29231_2], PartitionFilters: [], PushedFilters: [IsNotNull(c1)], ReadSchema: struct ``` ### Why are the changes needed? Improve query performance. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Unit test. Closes #27252 from wangyum/SPARK-29231. 
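To make the inference concrete, here is a small self-contained Scala sketch of the substitution idea behind the rule. The `Expr`, `Attr`, `CastToLong` and `Lit` classes below are illustrative stand-ins, not Spark's Catalyst expressions: given the join condition `cast(t2.c1 as bigint) = t1.c1` and the filter `t1.c1 = 1`, replacing `t1.c1` with `cast(t2.c1 as bigint)` yields the extra predicate that can be pushed to the other side of the join.

```scala
// Toy model of constraint inference across a cast-equality; names are illustrative only.
sealed trait Expr
case class Attr(name: String) extends Expr
case class CastToLong(child: Expr) extends Expr
case class Lit(value: Long) extends Expr
case class EqualTo(left: Expr, right: Expr) extends Expr

// Rewrite every occurrence of `source` inside `pred` to `destination`.
def replace(pred: Expr, source: Expr, destination: Expr): Expr = pred match {
  case e if e == source => destination
  case EqualTo(l, r)    => EqualTo(replace(l, source, destination), replace(r, source, destination))
  case e                => e
}

val joinCond = EqualTo(CastToLong(Attr("t2.c1")), Attr("t1.c1")) // cast(t2.c1 as bigint) = t1.c1
val filter   = EqualTo(Attr("t1.c1"), Lit(1L))                   // t1.c1 = 1

// Substitute t1.c1 with cast(t2.c1 as bigint) in the remaining predicate.
val inferred = joinCond match {
  case EqualTo(l: CastToLong, r: Attr) => replace(filter, r, l)
  case other                           => other
}
println(inferred) // EqualTo(CastToLong(Attr(t2.c1)),Lit(1)), i.e. cast(t2.c1 as bigint) = 1
```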
Authored-by: Yuming Wang Signed-off-by: Wenchen Fan --- .../plans/logical/QueryPlanConstraints.scala | 12 +++- .../InferFiltersFromConstraintsSuite.scala | 57 ++++++++++++++++++- 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 1355003358b9f..4c4ec000d0930 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -62,11 +62,17 @@ trait ConstraintHelper { */ def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = { var inferredConstraints = Set.empty[Expression] - constraints.foreach { + // IsNotNull should be constructed by `constructIsNotNullConstraints`. + val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull]) + predicates.foreach { case eq @ EqualTo(l: Attribute, r: Attribute) => - val candidateConstraints = constraints - eq + val candidateConstraints = predicates - eq inferredConstraints ++= replaceConstraints(candidateConstraints, l, r) inferredConstraints ++= replaceConstraints(candidateConstraints, r, l) + case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) => + inferredConstraints ++= replaceConstraints(predicates - eq, r, l) + case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) => + inferredConstraints ++= replaceConstraints(predicates - eq, l, r) case _ => // No inference } inferredConstraints -- constraints @@ -75,7 +81,7 @@ trait ConstraintHelper { private def replaceConstraints( constraints: Set[Expression], source: Expression, - destination: Attribute): Set[Expression] = constraints.map(_ transform { + destination: Expression): Set[Expression] = constraints.map(_ transform { case e: Expression if e.semanticEquals(source) => destination }) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 974bc781d36ab..79bd573f1d84a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, LongType} class InferFiltersFromConstraintsSuite extends PlanTest { @@ -46,8 +47,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { y: LogicalPlan, expectedLeft: LogicalPlan, expectedRight: LogicalPlan, - joinType: JoinType) = { - val condition = Some("x.a".attr === "y.a".attr) + joinType: JoinType, + condition: Option[Expression] = Some("x.a".attr === "y.a".attr)) = { val originalQuery = x.join(y, joinType, condition).analyze val correctAnswer = expectedLeft.join(expectedRight, joinType, condition).analyze val optimized = Optimize.execute(originalQuery) @@ -263,4 +264,56 @@ class InferFiltersFromConstraintsSuite extends PlanTest { val y = testRelation.subquery('y) testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) } + + test("Constraints should 
be inferred from cast equality constraint(filter higher data type)") { + val testRelation1 = LocalRelation('a.int) + val testRelation2 = LocalRelation('b.long) + val originalLeft = testRelation1.subquery('left) + val originalRight = testRelation2.where('b === 1L).subquery('right) + + val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left) + val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right) + + Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), + Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + } + + Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), + Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => + testConstraintsAfterJoin( + originalLeft, + originalRight, + testRelation1.where(IsNotNull('a)).subquery('left), + right, + Inner, + condition) + } + } + + test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") { + val testRelation1 = LocalRelation('a.int) + val testRelation2 = LocalRelation('b.long) + val originalLeft = testRelation1.where('a === 1).subquery('left) + val originalRight = testRelation2.subquery('right) + + val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left) + val right = testRelation2.where(IsNotNull('b)).subquery('right) + + Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), + Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => + testConstraintsAfterJoin(originalLeft, originalRight, left, right, Inner, condition) + } + + Seq(Some("left.a".attr === "right.b".attr.cast(IntegerType)), + Some("right.b".attr.cast(IntegerType) === "left.a".attr)).foreach { condition => + testConstraintsAfterJoin( + originalLeft, + originalRight, + left, + testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right), + Inner, + condition) + } + } } From 82d0aa37ae521231d8067e473c6ea79a118a115a Mon Sep 17 00:00:00 2001 From: Liang Zhang Date: Thu, 13 Feb 2020 23:55:13 +0800 Subject: [PATCH 035/185] [SPARK-30762] Add dtype=float32 support to vector_to_array UDF ### What changes were proposed in this pull request? In this PR, we add a parameter in the python function vector_to_array(col) that allows converting to a column of arrays of Float (32bits) in scala, which would be mapped to a numpy array of dtype=float32. ### Why are the changes needed? In the downstream ML training, using float32 instead of float64 (default) would allow a larger batch size, i.e., allow more data to fit in the memory. ### Does this PR introduce any user-facing change? Yes. Old: `vector_to_array()` only take one param ``` df.select(vector_to_array("colA"), ...) ``` New: `vector_to_array()` can take an additional optional param: `dtype` = "float32" (or "float64") ``` df.select(vector_to_array("colA", "float32"), ...) ``` ### How was this patch tested? Unit test in scala. doctest in python. Closes #27522 from liangz1/udf-float32. 
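As a quick usage sketch of the new `dtype` parameter (assuming a spark-shell session, i.e. an existing `SparkSession` named `spark`, with this patch applied):

```scala
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

val df = Seq(
  Tuple1(Vectors.dense(1.0, 2.0, 3.0)),
  Tuple1(Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))))
).toDF("vec")

// The default "float64" keeps array<double>; "float32" produces array<float>,
// roughly halving the memory footprint of the materialized arrays.
val converted = df.select(
  vector_to_array($"vec").alias("f64"),
  vector_to_array($"vec", "float32").alias("f32"))

converted.printSchema()
// f64: array<double>
// f32: array<float>
```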
Authored-by: Liang Zhang Signed-off-by: WeichenXu --- .../scala/org/apache/spark/ml/functions.scala | 34 ++++++++++++++++--- .../org/apache/spark/ml/FunctionsSuite.scala | 33 +++++++++++++++--- python/pyspark/ml/functions.py | 27 ++++++++++++--- 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/functions.scala b/mllib/src/main/scala/org/apache/spark/ml/functions.scala index 1faf562c4d896..0f03231079866 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/functions.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/functions.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.Since -import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.linalg.{SparseVector, Vector} import org.apache.spark.mllib.linalg.{Vector => OldVector} import org.apache.spark.sql.Column import org.apache.spark.sql.functions.udf @@ -27,7 +27,6 @@ import org.apache.spark.sql.functions.udf @Since("3.0.0") object functions { // scalastyle:on - private val vectorToArrayUdf = udf { vec: Any => vec match { case v: Vector => v.toArray @@ -39,10 +38,37 @@ object functions { } }.asNonNullable() + private val vectorToArrayFloatUdf = udf { vec: Any => + vec match { + case v: SparseVector => + val data = new Array[Float](v.size) + v.foreachActive { (index, value) => data(index) = value.toFloat } + data + case v: Vector => v.toArray.map(_.toFloat) + case v: OldVector => v.toArray.map(_.toFloat) + case v => throw new IllegalArgumentException( + "function vector_to_array requires a non-null input argument and input type must be " + + "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " + + s"but got ${ if (v == null) "null" else v.getClass.getName }.") + } + }.asNonNullable() + /** * Converts a column of MLlib sparse/dense vectors into a column of dense arrays. - * + * @param v: the column of MLlib sparse/dense vectors + * @param dtype: the desired underlying data type in the returned array + * @return an array<float> if dtype is float32, or array<double> if dtype is float64 * @since 3.0.0 */ - def vector_to_array(v: Column): Column = vectorToArrayUdf(v) + def vector_to_array(v: Column, dtype: String = "float64"): Column = { + if (dtype == "float64") { + vectorToArrayUdf(v) + } else if (dtype == "float32") { + vectorToArrayFloatUdf(v) + } else { + throw new IllegalArgumentException( + s"Unsupported dtype: $dtype. Valid values: float64, float32." 
+ ) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala index 2f5062c689fc7..3dd9a7d8ec85d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala @@ -34,9 +34,8 @@ class FunctionsSuite extends MLTest { (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) ).toDF("vec", "oldVec") - val result = df.select(vector_to_array('vec), vector_to_array('oldVec)) - .as[(Seq[Double], Seq[Double])] - .collect().toSeq + val result = df.select(vector_to_array('vec), vector_to_array('oldVec)) + .as[(Seq[Double], Seq[Double])].collect().toSeq val expected = Seq( (Seq(1.0, 2.0, 3.0), Seq(10.0, 20.0, 30.0)), @@ -50,7 +49,6 @@ class FunctionsSuite extends MLTest { (null, null, 0) ).toDF("vec", "oldVec", "label") - for ((colName, valType) <- Seq( ("vec", "null"), ("oldVec", "null"), ("label", "java.lang.Integer"))) { val thrown1 = intercept[SparkException] { @@ -61,5 +59,32 @@ class FunctionsSuite extends MLTest { "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " + s"but got ${valType}")) } + + val df3 = Seq( + (Vectors.dense(1.0, 2.0, 3.0), OldVectors.dense(10.0, 20.0, 30.0)), + (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) + ).toDF("vec", "oldVec") + val dfArrayFloat = df3.select( + vector_to_array('vec, dtype = "float32"), vector_to_array('oldVec, dtype = "float32")) + + // Check values are correct + val result3 = dfArrayFloat.as[(Seq[Float], Seq[Float])].collect().toSeq + + val expected3 = Seq( + (Seq(1.0, 2.0, 3.0), Seq(10.0, 20.0, 30.0)), + (Seq(2.0, 0.0, 3.0), Seq(20.0, 0.0, 30.0)) + ) + assert(result3 === expected3) + + // Check data types are correct + assert(dfArrayFloat.schema.simpleString === + "struct,UDF(oldVec):array>") + + val thrown2 = intercept[IllegalArgumentException] { + df3.select( + vector_to_array('vec, dtype = "float16"), vector_to_array('oldVec, dtype = "float16")) + } + assert(thrown2.getMessage.contains( + s"Unsupported dtype: float16. Valid values: float64, float32.")) } } diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py index 2b4d8ddcd00a8..ec164f34bc4db 100644 --- a/python/pyspark/ml/functions.py +++ b/python/pyspark/ml/functions.py @@ -19,10 +19,15 @@ from pyspark.sql.column import Column, _to_java_column -@since(3.0) -def vector_to_array(col): +@since("3.0.0") +def vector_to_array(col, dtype="float64"): """ Converts a column of MLlib sparse/dense vectors into a column of dense arrays. + :param col: A string of the column name or a Column + :param dtype: The data type of the output array. Valid values: "float64" or "float32". + :return: The converted column of dense arrays. + + .. versionadded:: 3.0.0 >>> from pyspark.ml.linalg import Vectors >>> from pyspark.ml.functions import vector_to_array @@ -32,14 +37,26 @@ def vector_to_array(col): ... (Vectors.sparse(3, [(0, 2.0), (2, 3.0)]), ... OldVectors.sparse(3, [(0, 20.0), (2, 30.0)]))], ... ["vec", "oldVec"]) - >>> df.select(vector_to_array("vec").alias("vec"), - ... vector_to_array("oldVec").alias("oldVec")).collect() + >>> df1 = df.select(vector_to_array("vec").alias("vec"), + ... 
vector_to_array("oldVec").alias("oldVec")) + >>> df1.collect() + [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]), + Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])] + >>> df2 = df.select(vector_to_array("vec", "float32").alias("vec"), + ... vector_to_array("oldVec", "float32").alias("oldVec")) + >>> df2.collect() [Row(vec=[1.0, 2.0, 3.0], oldVec=[10.0, 20.0, 30.0]), Row(vec=[2.0, 0.0, 3.0], oldVec=[20.0, 0.0, 30.0])] + >>> df1.schema.fields + [StructField(vec,ArrayType(DoubleType,false),false), + StructField(oldVec,ArrayType(DoubleType,false),false)] + >>> df2.schema.fields + [StructField(vec,ArrayType(FloatType,false),false), + StructField(oldVec,ArrayType(FloatType,false),false)] """ sc = SparkContext._active_spark_context return Column( - sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col))) + sc._jvm.org.apache.spark.ml.functions.vector_to_array(_to_java_column(col), dtype)) def _test(): From 3c4044ea77fe3b1268b52744cd4f1ae61f17a9a8 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 13 Feb 2020 10:53:55 -0800 Subject: [PATCH 036/185] [SPARK-30703][SQL][DOCS] Add a document for the ANSI mode ### What changes were proposed in this pull request? This pr intends to add a document for the ANSI mode; Screen Shot 2020-02-13 at 8 08 52 Screen Shot 2020-02-13 at 8 09 13 Screen Shot 2020-02-13 at 8 09 26 Screen Shot 2020-02-13 at 8 09 38 ### Why are the changes needed? For better document coverage and usability. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? N/A Closes #27489 from maropu/SPARK-30703. Authored-by: Takeshi Yamamuro Signed-off-by: Gengliang Wang --- docs/_data/menu-sql.yaml | 11 +- ...keywords.md => sql-ref-ansi-compliance.md} | 125 +++++++++++++++++- docs/sql-ref-arithmetic-ops.md | 22 --- 3 files changed, 132 insertions(+), 26 deletions(-) rename docs/{sql-keywords.md => sql-ref-ansi-compliance.md} (82%) delete mode 100644 docs/sql-ref-arithmetic-ops.md diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 241ec399d7bd5..1e343f630f88e 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -80,6 +80,15 @@ url: sql-ref-null-semantics.html - text: NaN Semantics url: sql-ref-nan-semantics.html + - text: ANSI Compliance + url: sql-ref-ansi-compliance.html + subitems: + - text: Arithmetic Operations + url: sql-ref-ansi-compliance.html#arithmetic-operations + - text: Type Conversion + url: sql-ref-ansi-compliance.html#type-conversion + - text: SQL Keywords + url: sql-ref-ansi-compliance.html#sql-keywords - text: SQL Syntax url: sql-ref-syntax.html subitems: @@ -214,5 +223,3 @@ url: sql-ref-syntax-aux-resource-mgmt-list-file.html - text: LIST JAR url: sql-ref-syntax-aux-resource-mgmt-list-jar.html - - text: Arithmetic operations - url: sql-ref-arithmetic-ops.html diff --git a/docs/sql-keywords.md b/docs/sql-ref-ansi-compliance.md similarity index 82% rename from docs/sql-keywords.md rename to docs/sql-ref-ansi-compliance.md index 9e4a3c54100c6..d02383518b967 100644 --- a/docs/sql-keywords.md +++ b/docs/sql-ref-ansi-compliance.md @@ -1,7 +1,7 @@ --- layout: global -title: Spark SQL Keywords -displayTitle: Spark SQL Keywords +title: ANSI Compliance +displayTitle: ANSI Compliance license: | Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with @@ -19,6 +19,127 @@ license: | limitations under the License. 
--- +Spark SQL has two options to comply with the SQL standard: `spark.sql.ansi.enabled` and `spark.sql.storeAssignmentPolicy` (See a table below for details). +When `spark.sql.ansi.enabled` is set to `true`, Spark SQL follows the standard in basic behaviours (e.g., arithmetic operations, type conversion, and SQL parsing). +Moreover, Spark SQL has an independent option to control implicit casting behaviours when inserting rows in a table. +The casting behaviours are defined as store assignment rules in the standard. +When `spark.sql.storeAssignmentPolicy` is set to `ANSI`, Spark SQL complies with the ANSI store assignment rules. + + + + + + + + + + + + + +
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.sql.ansi.enabled</code></td>
+  <td>false</td>
+  <td>
+    When true, Spark tries to conform to the ANSI SQL specification:
+    1. Spark will throw a runtime exception if an overflow occurs in any operation on integral/decimal field.
+    2. Spark will forbid using the reserved keywords of ANSI SQL as identifiers in the SQL parser.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.storeAssignmentPolicy</code></td>
+  <td>ANSI</td>
+  <td>
+    When inserting a value into a column with different data type, Spark will perform type coercion.
+    Currently, we support 3 policies for the type coercion rules: ANSI, legacy and strict. With ANSI policy,
+    Spark performs the type coercion as per ANSI SQL. In practice, the behavior is mostly the same as PostgreSQL.
+    It disallows certain unreasonable type conversions such as converting string to int or double to boolean.
+    With legacy policy, Spark allows the type coercion as long as it is a valid Cast, which is very loose.
+    e.g. converting string to int or double to boolean is allowed.
+    It is also the only behavior in Spark 2.x and it is compatible with Hive.
+    With strict policy, Spark doesn't allow any possible precision loss or data truncation in type coercion,
+    e.g. converting double to int or decimal to double is not allowed.
+  </td>
+</tr>
+ +The following subsections present behaviour changes in arithmetic operations, type conversions, and SQL parsing when the ANSI mode enabled. + +### Arithmetic Operations + +In Spark SQL, arithmetic operations performed on numeric types (with the exception of decimal) are not checked for overflows by default. +This means that in case an operation causes overflows, the result is the same that the same operation returns in a Java/Scala program (e.g., if the sum of 2 integers is higher than the maximum value representable, the result is a negative number). +On the other hand, Spark SQL returns null for decimal overflows. +When `spark.sql.ansi.enabled` is set to `true` and an overflow occurs in numeric and interval arithmetic operations, it throws an arithmetic exception at runtime. + +{% highlight sql %} +-- `spark.sql.ansi.enabled=true` +SELECT 2147483647 + 1; + + java.lang.ArithmeticException: integer overflow + +-- `spark.sql.ansi.enabled=false` +SELECT 2147483647 + 1; + + +----------------+ + |(2147483647 + 1)| + +----------------+ + | -2147483648| + +----------------+ + +{% endhighlight %} + +### Type Conversion + +Spark SQL has three kinds of type conversions: explicit casting, type coercion, and store assignment casting. +When `spark.sql.ansi.enabled` is set to `true`, explicit casting by `CAST` syntax throws a runtime exception for illegal cast patterns defined in the standard, e.g. casts from a string to an integer. +On the other hand, `INSERT INTO` syntax throws an analysis exception when the ANSI mode enabled via `spark.sql.storeAssignmentPolicy=ANSI`. + +Currently, the ANSI mode affects explicit casting and assignment casting only. +In future releases, the behaviour of type coercion might change along with the other two type conversion rules. + +{% highlight sql %} +-- Examples of explicit casting + +-- `spark.sql.ansi.enabled=true` +SELECT CAST('a' AS INT); + + java.lang.NumberFormatException: invalid input syntax for type numeric: a + +SELECT CAST(2147483648L AS INT); + + java.lang.ArithmeticException: Casting 2147483648 to int causes overflow + +-- `spark.sql.ansi.enabled=false` (This is a default behaviour) +SELECT CAST('a' AS INT); + + +--------------+ + |CAST(a AS INT)| + +--------------+ + | null| + +--------------+ + +SELECT CAST(2147483648L AS INT); + + +-----------------------+ + |CAST(2147483648 AS INT)| + +-----------------------+ + | -2147483648| + +-----------------------+ + +-- Examples of store assignment rules +CREATE TABLE t (v INT); + +-- `spark.sql.storeAssignmentPolicy=ANSI` +INSERT INTO t VALUES ('1'); + + org.apache.spark.sql.AnalysisException: Cannot write incompatible data to table '`default`.`t`': + - Cannot safely cast 'v': StringType to IntegerType; + +-- `spark.sql.storeAssignmentPolicy=LEGACY` (This is a legacy behaviour until Spark 2.x) +INSERT INTO t VALUES ('1'); +SELECT * FROM t; + + +---+ + | v| + +---+ + | 1| + +---+ + +{% endhighlight %} + +### SQL Keywords + When `spark.sql.ansi.enabled` is true, Spark SQL will use the ANSI mode parser. In this mode, Spark SQL has two kinds of keywords: * Reserved keywords: Keywords that are reserved and can't be used as identifiers for table, view, column, function, alias, etc. 
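The behaviours documented above can also be reproduced from Scala; the following is a rough spark-shell sketch (it assumes an existing `SparkSession` named `spark`, and the failing statements are left commented out with the expected exceptions noted):

```scala
// Legacy behaviour (default): silent wrap-around and lenient casts.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT 2147483647 + 1").show()    // -2147483648
spark.sql("SELECT CAST('a' AS INT)").show()  // null

// ANSI behaviour: the same statements fail at runtime.
spark.conf.set("spark.sql.ansi.enabled", "true")
// spark.sql("SELECT 2147483647 + 1").show()
//   java.lang.ArithmeticException: integer overflow
// spark.sql("SELECT CAST('a' AS INT)").show()
//   java.lang.NumberFormatException: invalid input syntax for type numeric: a

// Store assignment: with the ANSI policy, writing a string into an INT column is rejected.
spark.conf.set("spark.sql.storeAssignmentPolicy", "ANSI")
spark.sql("CREATE TABLE t (v INT) USING parquet")
// spark.sql("INSERT INTO t VALUES ('1')")
//   org.apache.spark.sql.AnalysisException: Cannot write incompatible data to table ...
```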
diff --git a/docs/sql-ref-arithmetic-ops.md b/docs/sql-ref-arithmetic-ops.md deleted file mode 100644 index 7bc8ffe31c990..0000000000000 --- a/docs/sql-ref-arithmetic-ops.md +++ /dev/null @@ -1,22 +0,0 @@ ---- -layout: global -title: Arithmetic Operations -displayTitle: Arithmetic Operations -license: | - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. ---- - -Operations performed on numeric types (with the exception of decimal) are not checked for overflow. This means that in case an operation causes an overflow, the result is the same that the same operation returns in a Java/Scala program (eg. if the sum of 2 integers is higher than the maximum value representable, the result is a negative number). From a4ceea6868002b88161517b14b94a2006be8af1b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 13 Feb 2020 20:09:24 +0100 Subject: [PATCH 037/185] [SPARK-30751][SQL] Combine the skewed readers into one in AQE skew join optimizations ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/26434 This PR use one special shuffle reader for skew join, so that we only have one join after optimization. In order to do that, this PR 1. add a very general `CustomShuffledRowRDD` which support all kind of partition arrangement. 2. move the logic of coalescing shuffle partitions to a util function, and call it during skew join optimization, to totally decouple with the `ReduceNumShufflePartitions` rule. It's too complicated to interfere skew join with `ReduceNumShufflePartitions`, as you need to consider the size of split partitions which don't respect target size already. ### Why are the changes needed? The current skew join optimization has a serious performance issue: the size of the query plan depends on the number and size of skewed partitions. ### Does this PR introduce any user-facing change? no ### How was this patch tested? 
existing tests test UI manually: ![image](https://user-images.githubusercontent.com/3182036/74357390-cfb30480-4dfa-11ea-83f6-825d1b9379ca.png) explain output ``` AdaptiveSparkPlan(isFinalPlan=true) +- OverwriteByExpression org.apache.spark.sql.execution.datasources.noop.NoopTable$403a2ed5, [AlwaysTrue()], org.apache.spark.sql.util.CaseInsensitiveStringMap1f +- *(5) SortMergeJoin(skew=true) [key1#2L], [key2#6L], Inner :- *(3) Sort [key1#2L ASC NULLS FIRST], false, 0 : +- SkewJoinShuffleReader 2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB) : +- ShuffleQueryStage 0 : +- Exchange hashpartitioning(key1#2L, 200), true, [id=#53] : +- *(1) Project [(id#0L % 2) AS key1#2L] : +- *(1) Filter isnotnull((id#0L % 2)) : +- *(1) Range (0, 100000, step=1, splits=6) +- *(4) Sort [key2#6L ASC NULLS FIRST], false, 0 +- SkewJoinShuffleReader 2 skewed partitions with size(max=5 KB, min=5 KB, avg=5 KB) +- ShuffleQueryStage 1 +- Exchange hashpartitioning(key2#6L, 200), true, [id=#64] +- *(2) Project [((id#4L % 2) + 1) AS key2#6L] +- *(2) Filter isnotnull(((id#4L % 2) + 1)) +- *(2) Range (0, 100000, step=1, splits=6) ``` Closes #27493 from cloud-fan/aqe. Authored-by: Wenchen Fan Signed-off-by: herman --- .../spark/sql/execution/ShuffledRowRDD.scala | 23 +- .../adaptive/CustomShuffledRowRDD.scala | 113 +++++++ .../adaptive/OptimizeLocalShuffleReader.scala | 2 +- .../adaptive/OptimizeSkewedJoin.scala | 276 +++++++++++------- .../adaptive/ReduceNumShufflePartitions.scala | 157 ++-------- .../adaptive/ShufflePartitionsCoalescer.scala | 112 +++++++ .../adaptive/SkewedShuffledRowRDD.scala | 78 ----- .../exchange/ShuffleExchangeExec.scala | 21 +- .../execution/joins/SortMergeJoinExec.scala | 13 +- .../ReduceNumShufflePartitionsSuite.scala | 210 +------------ .../ShufflePartitionsCoalescerSuite.scala | 220 ++++++++++++++ .../adaptive/AdaptiveQueryExecSuite.scala | 219 +++++--------- 12 files changed, 741 insertions(+), 703 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index efa493923ccc1..4c19f95796d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -116,7 +116,7 @@ class CoalescedPartitioner(val parent: Partitioner, val partitionStartIndices: A class ShuffledRowRDD( var dependency: ShuffleDependency[Int, InternalRow, InternalRow], metrics: Map[String, SQLMetric], - specifiedPartitionIndices: Option[Array[(Int, Int)]] = None) + specifiedPartitionStartIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) { @@ -126,8 +126,8 @@ class ShuffledRowRDD( private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions - private[this] val partitionStartIndices: Array[Int] = specifiedPartitionIndices match { - case Some(indices) => indices.map(_._1) + private[this] val partitionStartIndices: Array[Int] = 
specifiedPartitionStartIndices match { + case Some(indices) => indices case None => // When specifiedPartitionStartIndices is not defined, every post-shuffle partition // corresponds to a pre-shuffle partition. @@ -142,15 +142,16 @@ class ShuffledRowRDD( override val partitioner: Option[Partitioner] = Some(part) override def getPartitions: Array[Partition] = { - specifiedPartitionIndices match { - case Some(indices) => - Array.tabulate[Partition](indices.length) { i => - new ShuffledRowRDDPartition(i, indices(i)._1, indices(i)._2) - } - case None => - Array.tabulate[Partition](numPreShufflePartitions) { i => - new ShuffledRowRDDPartition(i, i, i + 1) + assert(partitionStartIndices.length == part.numPartitions) + Array.tabulate[Partition](partitionStartIndices.length) { i => + val startIndex = partitionStartIndices(i) + val endIndex = + if (i < partitionStartIndices.length - 1) { + partitionStartIndices(i + 1) + } else { + numPreShufflePartitions } + new ShuffledRowRDDPartition(i, startIndex, endIndex) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala new file mode 100644 index 0000000000000..5aba57443d632 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CustomShuffledRowRDD.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.{Dependency, MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} + +sealed trait ShufflePartitionSpec + +// A partition that reads data of one reducer. +case class SinglePartitionSpec(reducerIndex: Int) extends ShufflePartitionSpec + +// A partition that reads data of multiple reducers, from `startReducerIndex` (inclusive) to +// `endReducerIndex` (exclusive). +case class CoalescedPartitionSpec( + startReducerIndex: Int, endReducerIndex: Int) extends ShufflePartitionSpec + +// A partition that reads partial data of one reducer, from `startMapIndex` (inclusive) to +// `endMapIndex` (exclusive). +case class PartialPartitionSpec( + reducerIndex: Int, startMapIndex: Int, endMapIndex: Int) extends ShufflePartitionSpec + +private final case class CustomShufflePartition( + index: Int, spec: ShufflePartitionSpec) extends Partition + +// TODO: merge this with `ShuffledRowRDD`, and replace `LocalShuffledRowRDD` with this RDD. 
+class CustomShuffledRowRDD( + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + metrics: Map[String, SQLMetric], + partitionSpecs: Array[ShufflePartitionSpec]) + extends RDD[InternalRow](dependency.rdd.context, Nil) { + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override def clearDependencies() { + super.clearDependencies() + dependency = null + } + + override def getPartitions: Array[Partition] = { + Array.tabulate[Partition](partitionSpecs.length) { i => + CustomShufflePartition(i, partitionSpecs(i)) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + partition.asInstanceOf[CustomShufflePartition].spec match { + case SinglePartitionSpec(reducerIndex) => + tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) + + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + startReducerIndex.until(endReducerIndex).flatMap { reducerIndex => + tracker.getPreferredLocationsForShuffle(dependency, reducerIndex) + } + + case PartialPartitionSpec(_, startMapIndex, endMapIndex) => + tracker.getMapLocation(dependency, startMapIndex, endMapIndex) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + val reader = split.asInstanceOf[CustomShufflePartition].spec match { + case SinglePartitionSpec(reducerIndex) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager.getReader( + dependency.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + + case PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex) => + SparkEnv.get.shuffleManager.getReaderForRange( + dependency.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + } + reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala index a8d8f358ab660..e95441e28aafe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -71,7 +71,7 @@ case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { plan match { case c @ CoalescedShuffleReaderExec(s: ShuffleQueryStageExec, _) => LocalShuffleReaderExec( - s, getPartitionStartIndices(s, Some(c.partitionIndices.length))) + s, getPartitionStartIndices(s, Some(c.partitionStartIndices.length))) case s: ShuffleQueryStageExec => LocalShuffleReaderExec(s, getPartitionStartIndices(s, None)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala index 
74b7fbd317fc8..a716497c274b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.adaptive import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.commons.io.FileUtils + import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -44,11 +46,7 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { * partition size * spark.sql.adaptive.skewedPartitionFactor and also larger than * spark.sql.adaptive.skewedPartitionSizeThreshold. */ - private def isSkewed( - stats: MapOutputStatistics, - partitionId: Int, - medianSize: Long): Boolean = { - val size = stats.bytesByPartitionId(partitionId) + private def isSkewed(size: Long, medianSize: Long): Boolean = { size > medianSize * conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_FACTOR) && size > conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD) } @@ -108,12 +106,12 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { stage.resultOption.get.asInstanceOf[MapOutputStatistics] } - private def supportSplitOnLeftPartition(joinType: JoinType) = { + private def canSplitLeftSide(joinType: JoinType) = { joinType == Inner || joinType == Cross || joinType == LeftSemi || joinType == LeftAnti || joinType == LeftOuter } - private def supportSplitOnRightPartition(joinType: JoinType) = { + private def canSplitRightSide(joinType: JoinType) = { joinType == Inner || joinType == Cross || joinType == RightOuter } @@ -130,17 +128,18 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { * 1. Check whether the shuffle partition is skewed based on the median size * and the skewed partition threshold in origin smj. * 2. Assuming partition0 is skewed in left side, and it has 5 mappers (Map0, Map1...Map4). - * And we will split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] + * And we may split the 5 Mappers into 3 mapper ranges [(Map0, Map1), (Map2, Map3), (Map4)] * based on the map size and the max split number. - * 3. Create the 3 smjs with separately reading the above mapper ranges and then join with - * the Partition0 in right side. - * 4. Finally union the above 3 split smjs and the origin smj. + * 3. Wrap the join left child with a special shuffle reader that reads each mapper range with one + * task, so total 3 tasks. + * 4. Wrap the join right child with a special shuffle reader that reads partition0 3 times by + * 3 tasks separately. 
*/ def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp { - case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, + case smj @ SortMergeJoinExec(_, _, joinType, _, s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _), s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _) - if (supportedJoinTypes.contains(joinType)) => + if supportedJoinTypes.contains(joinType) => val leftStats = getStatistics(left) val rightStats = getStatistics(right) val numPartitions = leftStats.bytesByPartitionId.length @@ -155,61 +154,134 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { |Right side partition size: |${getSizeInfo(rightMedSize, rightStats.bytesByPartitionId.max)} """.stripMargin) + val canSplitLeft = canSplitLeftSide(joinType) + val canSplitRight = canSplitRightSide(joinType) + + val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] + val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec] + // This is used to delay the creation of non-skew partitions so that we can potentially + // coalesce them like `ReduceNumShufflePartitions` does. + val nonSkewPartitionIndices = mutable.ArrayBuffer.empty[Int] + val leftSkewDesc = new SkewDesc + val rightSkewDesc = new SkewDesc + for (partitionIndex <- 0 until numPartitions) { + val leftSize = leftStats.bytesByPartitionId(partitionIndex) + val isLeftSkew = isSkewed(leftSize, leftMedSize) && canSplitLeft + val rightSize = rightStats.bytesByPartitionId(partitionIndex) + val isRightSkew = isSkewed(rightSize, rightMedSize) && canSplitRight + if (isLeftSkew || isRightSkew) { + if (nonSkewPartitionIndices.nonEmpty) { + // As soon as we see a skew, we'll "flush" out unhandled non-skew partitions. + createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p => + leftSidePartitions += p + rightSidePartitions += p + } + nonSkewPartitionIndices.clear() + } - val skewedPartitions = mutable.HashSet[Int]() - val subJoins = mutable.ArrayBuffer[SparkPlan]() - for (partitionId <- 0 until numPartitions) { - val isLeftSkew = isSkewed(leftStats, partitionId, leftMedSize) - val isRightSkew = isSkewed(rightStats, partitionId, rightMedSize) - val leftMapIdStartIndices = if (isLeftSkew && supportSplitOnLeftPartition(joinType)) { - getMapStartIndices(left, partitionId) - } else { - Array(0) - } - val rightMapIdStartIndices = if (isRightSkew && supportSplitOnRightPartition(joinType)) { - getMapStartIndices(right, partitionId) - } else { - Array(0) - } + val leftParts = if (isLeftSkew) { + leftSkewDesc.addPartitionSize(leftSize) + createSkewPartitions( + partitionIndex, + getMapStartIndices(left, partitionIndex), + getNumMappers(left)) + } else { + Seq(SinglePartitionSpec(partitionIndex)) + } - if (leftMapIdStartIndices.length > 1 || rightMapIdStartIndices.length > 1) { - skewedPartitions += partitionId - for (i <- 0 until leftMapIdStartIndices.length; - j <- 0 until rightMapIdStartIndices.length) { - val leftEndMapId = if (i == leftMapIdStartIndices.length - 1) { - getNumMappers(left) - } else { - leftMapIdStartIndices(i + 1) - } - val rightEndMapId = if (j == rightMapIdStartIndices.length - 1) { - getNumMappers(right) - } else { - rightMapIdStartIndices(j + 1) + val rightParts = if (isRightSkew) { + rightSkewDesc.addPartitionSize(rightSize) + createSkewPartitions( + partitionIndex, + getMapStartIndices(right, partitionIndex), + getNumMappers(right)) + } else { + Seq(SinglePartitionSpec(partitionIndex)) + } + + for { + leftSidePartition <- leftParts + rightSidePartition <- 
rightParts + } { + leftSidePartitions += leftSidePartition + rightSidePartitions += rightSidePartition + } + } else { + // Add to `nonSkewPartitionIndices` first, and add real partitions later, in case we can + // coalesce the non-skew partitions. + nonSkewPartitionIndices += partitionIndex + // If this is the last partition, add real partition immediately. + if (partitionIndex == numPartitions - 1) { + createNonSkewPartitions(leftStats, rightStats, nonSkewPartitionIndices).foreach { p => + leftSidePartitions += p + rightSidePartitions += p } - // TODO: we may can optimize the sort merge join to broad cast join after - // obtaining the raw data size of per partition, - val leftSkewedReader = SkewedPartitionReaderExec( - left, partitionId, leftMapIdStartIndices(i), leftEndMapId) - val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId, - rightMapIdStartIndices(j), rightEndMapId) - subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, - s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader), true) + nonSkewPartitionIndices.clear() } } } - logDebug(s"number of skewed partitions is ${skewedPartitions.size}") - if (skewedPartitions.nonEmpty) { - val optimizedSmj = smj.copy( - left = s1.copy(child = PartialShuffleReaderExec(left, skewedPartitions.toSet)), - right = s2.copy(child = PartialShuffleReaderExec(right, skewedPartitions.toSet)), - isPartial = true) - subJoins += optimizedSmj - UnionExec(subJoins) + + logDebug("number of skewed partitions: " + + s"left ${leftSkewDesc.numPartitions}, right ${rightSkewDesc.numPartitions}") + if (leftSkewDesc.numPartitions > 0 || rightSkewDesc.numPartitions > 0) { + val newLeft = SkewJoinShuffleReaderExec( + left, leftSidePartitions.toArray, leftSkewDesc.toString) + val newRight = SkewJoinShuffleReaderExec( + right, rightSidePartitions.toArray, rightSkewDesc.toString) + smj.copy( + left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true) } else { smj } } + private def createNonSkewPartitions( + leftStats: MapOutputStatistics, + rightStats: MapOutputStatistics, + nonSkewPartitionIndices: Seq[Int]): Seq[ShufflePartitionSpec] = { + assert(nonSkewPartitionIndices.nonEmpty) + if (nonSkewPartitionIndices.length == 1) { + Seq(SinglePartitionSpec(nonSkewPartitionIndices.head)) + } else { + val startIndices = ShufflePartitionsCoalescer.coalescePartitions( + Array(leftStats, rightStats), + firstPartitionIndex = nonSkewPartitionIndices.head, + // `lastPartitionIndex` is exclusive. + lastPartitionIndex = nonSkewPartitionIndices.last + 1, + advisoryTargetSize = conf.targetPostShuffleInputSize) + startIndices.indices.map { i => + val startIndex = startIndices(i) + val endIndex = if (i == startIndices.length - 1) { + // `endIndex` is exclusive. + nonSkewPartitionIndices.last + 1 + } else { + startIndices(i + 1) + } + // Do not create `CoalescedPartitionSpec` if only need to read a singe partition. 
+ if (startIndex + 1 == endIndex) { + SinglePartitionSpec(startIndex) + } else { + CoalescedPartitionSpec(startIndex, endIndex) + } + } + } + } + + private def createSkewPartitions( + reducerIndex: Int, + mapStartIndices: Array[Int], + numMappers: Int): Seq[PartialPartitionSpec] = { + mapStartIndices.indices.map { i => + val startMapIndex = mapStartIndices(i) + val endMapIndex = if (i == mapStartIndices.length - 1) { + numMappers + } else { + mapStartIndices(i + 1) + } + PartialPartitionSpec(reducerIndex, startMapIndex, endMapIndex) + } + } + override def apply(plan: SparkPlan): SparkPlan = { if (!conf.getConf(SQLConf.ADAPTIVE_EXECUTION_SKEWED_JOIN_ENABLED)) { return plan @@ -248,79 +320,69 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] { } } -/** - * A wrapper of shuffle query stage, which submits one reduce task to read a single - * shuffle partition 'partitionIndex' produced by the mappers in range [startMapIndex, endMapIndex). - * This is used to increase the parallelism when reading skewed partitions. - * - * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange - * node during canonicalization. - * @param partitionIndex The pre shuffle partition index. - * @param startMapIndex The start map index. - * @param endMapIndex The end map index. - */ -case class SkewedPartitionReaderExec( - child: QueryStageExec, - partitionIndex: Int, - startMapIndex: Int, - endMapIndex: Int) extends LeafExecNode { +private class SkewDesc { + private[this] var numSkewedPartitions: Int = 0 + private[this] var totalSize: Long = 0 + private[this] var maxSize: Long = 0 + private[this] var minSize: Long = 0 - override def output: Seq[Attribute] = child.output + def numPartitions: Int = numSkewedPartitions - override def outputPartitioning: Partitioning = { - UnknownPartitioning(1) + def addPartitionSize(size: Long): Unit = { + if (numSkewedPartitions == 0) { + maxSize = size + minSize = size + } + numSkewedPartitions += 1 + totalSize += size + if (size > maxSize) maxSize = size + if (size < minSize) minSize = size } - private var cachedSkewedShuffleRDD: SkewedShuffledRowRDD = null - override def doExecute(): RDD[InternalRow] = { - if (cachedSkewedShuffleRDD == null) { - cachedSkewedShuffleRDD = child match { - case stage: ShuffleQueryStageExec => - stage.shuffle.createSkewedShuffleRDD(partitionIndex, startMapIndex, endMapIndex) - case _ => - throw new IllegalStateException("operating on canonicalization plan") - } + override def toString: String = { + if (numSkewedPartitions == 0) { + "no skewed partition" + } else { + val maxSizeStr = FileUtils.byteCountToDisplaySize(maxSize) + val minSizeStr = FileUtils.byteCountToDisplaySize(minSize) + val avgSizeStr = FileUtils.byteCountToDisplaySize(totalSize / numSkewedPartitions) + s"$numSkewedPartitions skewed partitions with " + + s"size(max=$maxSizeStr, min=$minSizeStr, avg=$avgSizeStr)" } - cachedSkewedShuffleRDD } } /** - * A wrapper of shuffle query stage, which skips some partitions when reading the shuffle blocks. + * A wrapper of shuffle query stage, which follows the given partition arrangement. * * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during * canonicalization. - * @param excludedPartitions The partitions to skip when reading. + * @param partitionSpecs The partition specs that defines the arrangement. + * @param skewDesc The description of the skewed partitions. 
*/ -case class PartialShuffleReaderExec( - child: QueryStageExec, - excludedPartitions: Set[Int]) extends UnaryExecNode { +case class SkewJoinShuffleReaderExec( + child: SparkPlan, + partitionSpecs: Array[ShufflePartitionSpec], + skewDesc: String) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = { - UnknownPartitioning(1) + UnknownPartitioning(partitionSpecs.length) } - private def shuffleExchange(): ShuffleExchangeExec = child match { - case stage: ShuffleQueryStageExec => stage.shuffle - case _ => - throw new IllegalStateException("operating on canonicalization plan") - } - - private def getPartitionIndexRanges(): Array[(Int, Int)] = { - val length = shuffleExchange().shuffleDependency.partitioner.numPartitions - (0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray - } + override def stringArgs: Iterator[Any] = Iterator(skewDesc) private var cachedShuffleRDD: RDD[InternalRow] = null - override def doExecute(): RDD[InternalRow] = { + override protected def doExecute(): RDD[InternalRow] = { if (cachedShuffleRDD == null) { - cachedShuffleRDD = if (excludedPartitions.isEmpty) { - child.execute() - } else { - shuffleExchange().createShuffledRDD(Some(getPartitionIndexRanges())) + cachedShuffleRDD = child match { + case stage: ShuffleQueryStageExec => + new CustomShuffledRowRDD( + stage.shuffle.shuffleDependency, stage.shuffle.readMetrics, partitionSpecs) + case _ => + throw new IllegalStateException("operating on canonicalization plan") } } cachedShuffleRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala index 2c50b638b4d45..5bbcb14e008d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.adaptive -import scala.collection.mutable.{ArrayBuffer, HashSet} - import org.apache.spark.MapOutputStatistics import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -29,24 +27,8 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf /** - * A rule to adjust the post shuffle partitions based on the map output statistics. - * - * The strategy used to determine the number of post-shuffle partitions is described as follows. - * To determine the number of post-shuffle partitions, we have a target input size for a - * post-shuffle partition. Once we have size statistics of all pre-shuffle partitions, we will do - * a pass of those statistics and pack pre-shuffle partitions with continuous indices to a single - * post-shuffle partition until adding another pre-shuffle partition would cause the size of a - * post-shuffle partition to be greater than the target size. 
- * - * For example, we have two stages with the following pre-shuffle partition size statistics: - * stage 1: [100 MiB, 20 MiB, 100 MiB, 10MiB, 30 MiB] - * stage 2: [10 MiB, 10 MiB, 70 MiB, 5 MiB, 5 MiB] - * assuming the target input size is 128 MiB, we will have four post-shuffle partitions, - * which are: - * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MiB) - * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MiB) - * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MiB) - * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MiB) + * A rule to reduce the post shuffle partitions based on the map output statistics, which can + * avoid many small reduce tasks that hurt performance. */ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { @@ -54,28 +36,21 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { if (!conf.reducePostShufflePartitionsEnabled) { return plan } - // 'SkewedShufflePartitionReader' is added by us, so it's safe to ignore it when changing - // number of reducers. - val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedPartitionReaderExec]) - if (!leafNodes.forall(_.isInstanceOf[QueryStageExec])) { + if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])) { // If not all leaf nodes are query stages, it's not safe to reduce the number of // shuffle partitions, because we may break the assumption that all children of a spark plan // have same number of output partitions. return plan } - def collectShuffles(plan: SparkPlan): Seq[SparkPlan] = plan match { + def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match { case _: LocalShuffleReaderExec => Nil - case p: PartialShuffleReaderExec => Seq(p) + case _: SkewJoinShuffleReaderExec => Nil case stage: ShuffleQueryStageExec => Seq(stage) - case _ => plan.children.flatMap(collectShuffles) + case _ => plan.children.flatMap(collectShuffleStages) } - val shuffles = collectShuffles(plan) - val shuffleStages = shuffles.map { - case PartialShuffleReaderExec(s: ShuffleQueryStageExec, _) => s - case s: ShuffleQueryStageExec => s - } + val shuffleStages = collectShuffleStages(plan) // ShuffleExchanges introduced by repartition do not support changing the number of partitions. // We change the number of partitions in the stage only if all the ShuffleExchanges support it. if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) { @@ -94,110 +69,27 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { // partition) and a result of a SortMergeJoin (multiple partitions). 
val distinctNumPreShufflePartitions = validMetrics.map(stats => stats.bytesByPartitionId.length).distinct - val distinctExcludedPartitions = shuffles.map { - case PartialShuffleReaderExec(_, excludedPartitions) => excludedPartitions - case _: ShuffleQueryStageExec => Set.empty[Int] - }.distinct - if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1 - && distinctExcludedPartitions.length == 1) { - val excludedPartitions = distinctExcludedPartitions.head - val partitionIndices = estimatePartitionStartAndEndIndices( - validMetrics.toArray, excludedPartitions) + if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) { + val partitionStartIndices = ShufflePartitionsCoalescer.coalescePartitions( + validMetrics.toArray, + firstPartitionIndex = 0, + lastPartitionIndex = distinctNumPreShufflePartitions.head, + advisoryTargetSize = conf.targetPostShuffleInputSize, + minNumPartitions = conf.minNumPostShufflePartitions) // This transformation adds new nodes, so we must use `transformUp` here. - // Even for shuffle exchange whose input RDD has 0 partition, we should still update its - // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same - // number of output partitions. - val visitedStages = HashSet.empty[Int] - plan.transformDown { - // Replace `PartialShuffleReaderExec` with `CoalescedShuffleReaderExec`, which keeps the - // "excludedPartition" requirement and also merges some partitions. - case PartialShuffleReaderExec(stage: ShuffleQueryStageExec, _) => - visitedStages.add(stage.id) - CoalescedShuffleReaderExec(stage, partitionIndices) - - // We are doing `transformDown`, so the `ShuffleQueryStageExec` may already be optimized - // and wrapped by `CoalescedShuffleReaderExec`. - case stage: ShuffleQueryStageExec if !visitedStages.contains(stage.id) => - visitedStages.add(stage.id) - CoalescedShuffleReaderExec(stage, partitionIndices) + val stageIds = shuffleStages.map(_.id).toSet + plan.transformUp { + // even for shuffle exchange whose input RDD has 0 partition, we should still update its + // `partitionStartIndices`, so that all the leaf shuffles in a stage have the same + // number of output partitions. + case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) => + CoalescedShuffleReaderExec(stage, partitionStartIndices) } } else { plan } } } - - /** - * Estimates partition start and end indices for post-shuffle partitions based on - * mapOutputStatistics provided by all pre-shuffle stages and skip the omittedPartitions - * already handled in skewed partition optimization. - */ - // visible for testing. - private[sql] def estimatePartitionStartAndEndIndices( - mapOutputStatistics: Array[MapOutputStatistics], - excludedPartitions: Set[Int] = Set.empty): Array[(Int, Int)] = { - val minNumPostShufflePartitions = conf.minNumPostShufflePartitions - excludedPartitions.size - val advisoryTargetPostShuffleInputSize = conf.targetPostShuffleInputSize - // If minNumPostShufflePartitions is defined, it is possible that we need to use a - // value less than advisoryTargetPostShuffleInputSize as the target input size of - // a post shuffle task. - val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum - // The max at here is to make sure that when we have an empty table, we - // only have a single post-shuffle partition. - // There is no particular reason that we pick 16. We just need a number to - // prevent maxPostShuffleInputSize from being set to 0. 
- val maxPostShuffleInputSize = math.max( - math.ceil(totalPostShuffleInputSize / minNumPostShufflePartitions.toDouble).toLong, 16) - val targetPostShuffleInputSize = - math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) - - logInfo( - s"advisoryTargetPostShuffleInputSize: $advisoryTargetPostShuffleInputSize, " + - s"targetPostShuffleInputSize $targetPostShuffleInputSize.") - - // Make sure we do get the same number of pre-shuffle partitions for those stages. - val distinctNumPreShufflePartitions = - mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct - // The reason that we are expecting a single value of the number of pre-shuffle partitions - // is that when we add Exchanges, we set the number of pre-shuffle partitions - // (i.e. map output partitions) using a static setting, which is the value of - // spark.sql.shuffle.partitions. Even if two input RDDs are having different - // number of partitions, they will have the same number of pre-shuffle partitions - // (i.e. map output partitions). - assert( - distinctNumPreShufflePartitions.length == 1, - "There should be only one distinct value of the number pre-shuffle partitions " + - "among registered Exchange operator.") - - val partitionStartIndices = ArrayBuffer[Int]() - val partitionEndIndices = ArrayBuffer[Int]() - val numPartitions = distinctNumPreShufflePartitions.head - val includedPartitions = (0 until numPartitions).filter(!excludedPartitions.contains(_)) - val firstStartIndex = includedPartitions(0) - partitionStartIndices += firstStartIndex - var postShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId(firstStartIndex)).sum - var i = firstStartIndex - includedPartitions.drop(1).foreach { nextPartitionIndex => - val nextShuffleInputSize = - mapOutputStatistics.map(_.bytesByPartitionId(nextPartitionIndex)).sum - // If nextPartitionIndices is skewed and omitted, or including - // the nextShuffleInputSize would exceed the target partition size, - // then start a new partition. - if (nextPartitionIndex != i + 1 || - (postShuffleInputSize + nextShuffleInputSize > targetPostShuffleInputSize)) { - partitionEndIndices += i + 1 - partitionStartIndices += nextPartitionIndex - // reset postShuffleInputSize. - postShuffleInputSize = nextShuffleInputSize - i = nextPartitionIndex - } else { - postShuffleInputSize += nextShuffleInputSize - i += 1 - } - } - partitionEndIndices += i + 1 - partitionStartIndices.zip(partitionEndIndices).toArray - } } /** @@ -206,15 +98,16 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] { * * @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during * canonicalization. + * @param partitionStartIndices The start partition indices for the coalesced partitions. 
*/ case class CoalescedShuffleReaderExec( child: SparkPlan, - partitionIndices: Array[(Int, Int)]) extends UnaryExecNode { + partitionStartIndices: Array[Int]) extends UnaryExecNode { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = { - UnknownPartitioning(partitionIndices.length) + UnknownPartitioning(partitionStartIndices.length) } private var cachedShuffleRDD: ShuffledRowRDD = null @@ -223,7 +116,7 @@ case class CoalescedShuffleReaderExec( if (cachedShuffleRDD == null) { cachedShuffleRDD = child match { case stage: ShuffleQueryStageExec => - stage.shuffle.createShuffledRDD(Some(partitionIndices)) + stage.shuffle.createShuffledRDD(Some(partitionStartIndices)) case _ => throw new IllegalStateException("operating on canonicalization plan") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala new file mode 100644 index 0000000000000..18f0585524aa2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsCoalescer.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.MapOutputStatistics +import org.apache.spark.internal.Logging + +object ShufflePartitionsCoalescer extends Logging { + + /** + * Coalesce the same range of partitions (`firstPartitionIndex`` to `lastPartitionIndex`, the + * start is inclusive and the end is exclusive) from multiple shuffles. This method assumes that + * all the shuffles have the same number of partitions, and the partitions of same index will be + * read together by one task. + * + * The strategy used to determine the number of coalesced partitions is described as follows. + * To determine the number of coalesced partitions, we have a target size for a coalesced + * partition. Once we have size statistics of all shuffle partitions, we will do + * a pass of those statistics and pack shuffle partitions with continuous indices to a single + * coalesced partition until adding another shuffle partition would cause the size of a + * coalesced partition to be greater than the target size. 
+ * + * For example, we have two shuffles with the following partition size statistics: + * - shuffle 1 (5 partitions): [100 MiB, 20 MiB, 100 MiB, 10MiB, 30 MiB] + * - shuffle 2 (5 partitions): [10 MiB, 10 MiB, 70 MiB, 5 MiB, 5 MiB] + * Assuming the target size is 128 MiB, we will have 4 coalesced partitions, which are: + * - coalesced partition 0: shuffle partition 0 (size 110 MiB) + * - coalesced partition 1: shuffle partition 1 (size 30 MiB) + * - coalesced partition 2: shuffle partition 2 (size 170 MiB) + * - coalesced partition 3: shuffle partition 3 and 4 (size 50 MiB) + * + * @return An array of partition indices which represents the coalesced partitions. For example, + * [0, 2, 3] means 3 coalesced partitions: [0, 2), [2, 3), [3, lastPartitionIndex] + */ + def coalescePartitions( + mapOutputStatistics: Array[MapOutputStatistics], + firstPartitionIndex: Int, + lastPartitionIndex: Int, + advisoryTargetSize: Long, + minNumPartitions: Int = 1): Array[Int] = { + // If `minNumPartitions` is very large, it is possible that we need to use a value less than + // `advisoryTargetSize` as the target size of a coalesced task. + val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum + // The max at here is to make sure that when we have an empty table, we only have a single + // coalesced partition. + // There is no particular reason that we pick 16. We just need a number to prevent + // `maxTargetSize` from being set to 0. + val maxTargetSize = math.max( + math.ceil(totalPostShuffleInputSize / minNumPartitions.toDouble).toLong, 16) + val targetSize = math.min(maxTargetSize, advisoryTargetSize) + + logInfo(s"advisory target size: $advisoryTargetSize, actual target size $targetSize.") + + // Make sure these shuffles have the same number of partitions. + val distinctNumShufflePartitions = + mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of shuffle partitions + // is that when we add Exchanges, we set the number of shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // `spark.sql.shuffle.partitions`. Even if two input RDDs are having different + // number of partitions, they will have the same number of shuffle partitions + // (i.e. map output partitions). + assert( + distinctNumShufflePartitions.length == 1, + "There should be only one distinct value of the number of shuffle partitions " + + "among registered Exchange operators.") + + val splitPoints = ArrayBuffer[Int]() + splitPoints += firstPartitionIndex + var coalescedSize = 0L + var i = firstPartitionIndex + while (i < lastPartitionIndex) { + // We calculate the total size of i-th shuffle partitions from all shuffles. + var totalSizeOfCurrentPartition = 0L + var j = 0 + while (j < mapOutputStatistics.length) { + totalSizeOfCurrentPartition += mapOutputStatistics(j).bytesByPartitionId(i) + j += 1 + } + + // If including the `totalSizeOfCurrentPartition` would exceed the target size, then start a + // new coalesced partition. + if (i > firstPartitionIndex && coalescedSize + totalSizeOfCurrentPartition > targetSize) { + splitPoints += i + // reset postShuffleInputSize. 
+ coalescedSize = totalSizeOfCurrentPartition + } else { + coalescedSize += totalSizeOfCurrentPartition + } + i += 1 + } + + splitPoints.toArray + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala deleted file mode 100644 index 52f793b24aa17..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/SkewedShuffledRowRDD.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.adaptive - -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} - -/** - * The [[Partition]] used by [[SkewedShuffledRowRDD]]. - */ -class SkewedShuffledRowRDDPartition(override val index: Int) extends Partition - -/** - * This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used - * in Spark SQL adaptive execution to solve data skew issues. This RDD includes rearranged - * partitions from mappers. - * - * This RDD takes a [[ShuffleDependency]] (`dependency`), a partitionIndex - * and the range of startMapIndex to endMapIndex. - */ -class SkewedShuffledRowRDD( - var dependency: ShuffleDependency[Int, InternalRow, InternalRow], - partitionIndex: Int, - startMapIndex: Int, - endMapIndex: Int, - metrics: Map[String, SQLMetric]) - extends RDD[InternalRow](dependency.rdd.context, Nil) { - - override def getDependencies: Seq[Dependency[_]] = List(dependency) - - override def getPartitions: Array[Partition] = { - Array(new SkewedShuffledRowRDDPartition(0)) - } - - override def getPreferredLocations(partition: Partition): Seq[String] = { - val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] - tracker.getMapLocation(dependency, startMapIndex, endMapIndex) - } - - override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { - val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() - // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, - // as well as the `tempMetrics` for basic shuffle metrics. 
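To make the coalescing pass in the new `ShufflePartitionsCoalescer` easier to follow, here is a self-contained sketch that reproduces the example from its comment (two shuffles, 128 MiB target) using plain arrays instead of `MapOutputStatistics`. The object name and `main` wrapper are illustrative only, and the minimum-partition and empty-input handling of the real method are omitted; this is a sketch of the packing logic, not the actual Spark code.

```
import scala.collection.mutable.ArrayBuffer

// Simplified stand-in for ShufflePartitionsCoalescer.coalescePartitions: it keeps only the
// "pack consecutive partitions until the target size would be exceeded" part of the logic.
object CoalesceSketch {
  def splitPoints(bytesByPartition: Seq[Array[Long]], targetSize: Long): Array[Int] = {
    val numPartitions = bytesByPartition.head.length
    val splits = ArrayBuffer(0)
    var coalescedSize = 0L
    var i = 0
    while (i < numPartitions) {
      // Total size of the i-th partition across all shuffles.
      val partitionSize = bytesByPartition.map(_(i)).sum
      if (i > 0 && coalescedSize + partitionSize > targetSize) {
        splits += i                      // start a new coalesced partition at index i
        coalescedSize = partitionSize
      } else {
        coalescedSize += partitionSize   // keep packing into the current coalesced partition
      }
      i += 1
    }
    splits.toArray
  }

  def main(args: Array[String]): Unit = {
    val mib = 1024L * 1024L
    val shuffle1 = Array(100L, 20L, 100L, 10L, 30L).map(_ * mib)
    val shuffle2 = Array(10L, 10L, 70L, 5L, 5L).map(_ * mib)
    // Prints [0, 1, 2, 3]: four coalesced partitions, the last one covering partitions 3 and 4,
    // matching the worked example in the comment of coalescePartitions.
    println(splitPoints(Seq(shuffle1, shuffle2), 128 * mib).mkString("[", ", ", "]"))
  }
}
```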
- val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) - - val reader = SparkEnv.get.shuffleManager.getReaderForRange( - dependency.shuffleHandle, - startMapIndex, - endMapIndex, - partitionIndex, - partitionIndex + 1, - context, - sqlMetricsReporter) - reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) - } - - override def clearDependencies() { - super.clearDependencies() - dependency = null - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index ffcd6c7783354..4b08da043b83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,11 +30,11 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Divide, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{LocalShuffledRowRDD, SkewedShuffledRowRDD} +import org.apache.spark.sql.execution.adaptive.LocalShuffledRowRDD import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -49,11 +49,9 @@ case class ShuffleExchangeExec( child: SparkPlan, canChangeNumPartitions: Boolean = true) extends Exchange { - // NOTE: coordinator can be null after serialization/deserialization, - // e.g. it can be null on the Executor side private lazy val writeMetrics = SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) - private lazy val readMetrics = + private[sql] lazy val readMetrics = SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) override lazy val metrics = Map( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size") @@ -90,9 +88,8 @@ case class ShuffleExchangeExec( writeMetrics) } - def createShuffledRDD( - partitionRanges: Option[Array[(Int, Int)]]): ShuffledRowRDD = { - new ShuffledRowRDD(shuffleDependency, readMetrics, partitionRanges) + def createShuffledRDD(partitionStartIndices: Option[Array[Int]]): ShuffledRowRDD = { + new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices) } def createLocalShuffleRDD( @@ -100,14 +97,6 @@ case class ShuffleExchangeExec( new LocalShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndicesPerMapper) } - def createSkewedShuffleRDD( - partitionIndex: Int, - startMapIndex: Int, - endMapIndex: Int): SkewedShuffledRowRDD = { - new SkewedShuffledRowRDD(shuffleDependency, - partitionIndex, startMapIndex, endMapIndex, readMetrics) - } - /** * Caches the created ShuffleRowRDD so we can reuse that. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6384aed6a78e0..62eea611556ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{PartialShuffleReaderExec, SkewedPartitionReaderExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -42,11 +41,17 @@ case class SortMergeJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isPartial: Boolean = false) extends BinaryExecNode with CodegenSupport { + isSkewJoin: Boolean = false) extends BinaryExecNode with CodegenSupport { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override def nodeName: String = { + if (isSkewJoin) super.nodeName + "(skew=true)" else super.nodeName + } + + override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator + override def simpleStringWithNodeId(): String = { val opId = ExplainUtils.getOpId(this) s"$nodeName $joinType ($opId)".trim @@ -98,7 +103,9 @@ case class SortMergeJoinExec( } override def requiredChildDistribution: Seq[Distribution] = { - if (isPartial) { + if (isSkewJoin) { + // We re-arrange the shuffle partitions to deal with skew join, and the new children + // partitioning doesn't satisfy `HashClusteredDistribution`. 
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil } else { HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala index 04b4d4f29f850..5565a0dd01840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql._ import org.apache.spark.sql.execution.adaptive._ -import org.apache.spark.sql.execution.adaptive.{CoalescedShuffleReaderExec, ReduceNumShufflePartitions} +import org.apache.spark.sql.execution.adaptive.CoalescedShuffleReaderExec import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -52,212 +52,6 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA } } - private def checkEstimation( - rule: ReduceNumShufflePartitions, - bytesByPartitionIdArray: Array[Array[Long]], - expectedPartitionStartIndices: Array[Int]): Unit = { - val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { - case (bytesByPartitionId, index) => - new MapOutputStatistics(index, bytesByPartitionId) - } - val estimatedPartitionStartIndices = - rule.estimatePartitionStartAndEndIndices(mapOutputStatistics).map(_._1) - assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) - } - - private def createReduceNumShufflePartitionsRule( - advisoryTargetPostShuffleInputSize: Long, - minNumPostShufflePartitions: Int = 1): ReduceNumShufflePartitions = { - val conf = new SQLConf().copy( - SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE -> advisoryTargetPostShuffleInputSize, - SQLConf.SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS -> minNumPostShufflePartitions) - ReduceNumShufflePartitions(conf) - } - - test("test estimatePartitionStartIndices - 1 Exchange") { - val rule = createReduceNumShufflePartitionsRule(100L) - - { - // All bytes per partition are 0. - val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // Some bytes per partition are 0 and total size is less than the target size. - // 1 post-shuffle partition is needed. - val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // 2 post-shuffle partitions are needed. - val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) - val expectedPartitionStartIndices = Array[Int](0, 3) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // There are a few large pre-shuffle partitions. 
- val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // All pre-shuffle partitions are larger than the targeted size. - val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - - { - // The last pre-shuffle partition is in a single post-shuffle partition. - val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110) - val expectedPartitionStartIndices = Array[Int](0, 4) - checkEstimation(rule, Array(bytesByPartitionId), expectedPartitionStartIndices) - } - } - - test("test estimatePartitionStartIndices - 2 Exchanges") { - val rule = createReduceNumShufflePartitionsRule(100L) - - { - // If there are multiple values of the number of pre-shuffle partitions, - // we should see an assertion error. - val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) - val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) - val mapOutputStatistics = - Array( - new MapOutputStatistics(0, bytesByPartitionId1), - new MapOutputStatistics(1, bytesByPartitionId2)) - intercept[AssertionError](rule.estimatePartitionStartAndEndIndices( - mapOutputStatistics)) - } - - { - // All bytes per partition are 0. - val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) - val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // Some bytes per partition are 0. - // 1 post-shuffle partition is needed. - val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // 2 post-shuffle partition are needed. - val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) - val expectedPartitionStartIndices = Array[Int](0, 2, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // 4 post-shuffle partition are needed. - val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // 2 post-shuffle partition are needed. - val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // There are a few large pre-shuffle partitions. - val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // All pairs of pre-shuffle partitions are larger than the targeted size. 
- val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0) - val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110) - val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - } - - test("test estimatePartitionStartIndices and enforce minimal number of reducers") { - val rule = createReduceNumShufflePartitionsRule(100L, 2) - - { - // The minimal number of post-shuffle partitions is not enforced because - // the size of data is 0. - val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) - val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) - val expectedPartitionStartIndices = Array[Int](0) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // The minimal number of post-shuffle partitions is enforced. - val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20) - val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5) - val expectedPartitionStartIndices = Array[Int](0, 3) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - - { - // The number of post-shuffle partitions is determined by the coordinator. - val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20) - val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30) - val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4) - checkEstimation( - rule, - Array(bytesByPartitionId1, bytesByPartitionId2), - expectedPartitionStartIndices) - } - } - - /////////////////////////////////////////////////////////////////////////// - // Query tests - /////////////////////////////////////////////////////////////////////////// - val numInputPartitions: Int = 10 def withSparkSession( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala new file mode 100644 index 0000000000000..fcfde83b2ffd5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsCoalescerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{MapOutputStatistics, SparkFunSuite} +import org.apache.spark.sql.execution.adaptive.ShufflePartitionsCoalescer + +class ShufflePartitionsCoalescerSuite extends SparkFunSuite { + + private def checkEstimation( + bytesByPartitionIdArray: Array[Array[Long]], + expectedPartitionStartIndices: Array[Int], + targetSize: Long, + minNumPartitions: Int = 1): Unit = { + val mapOutputStatistics = bytesByPartitionIdArray.zipWithIndex.map { + case (bytesByPartitionId, index) => + new MapOutputStatistics(index, bytesByPartitionId) + } + val estimatedPartitionStartIndices = ShufflePartitionsCoalescer.coalescePartitions( + mapOutputStatistics, + 0, + bytesByPartitionIdArray.head.length, + targetSize, + minNumPartitions) + assert(estimatedPartitionStartIndices === expectedPartitionStartIndices) + } + + test("1 shuffle") { + val targetSize = 100 + + { + // All bytes per partition are 0. + val bytesByPartitionId = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // Some bytes per partition are 0 and total size is less than the target size. + // 1 coalesced partition is expected. + val bytesByPartitionId = Array[Long](10, 0, 20, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // 2 coalesced partitions are expected. + val bytesByPartitionId = Array[Long](10, 0, 90, 20, 0) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // There are a few large shuffle partitions. + val bytesByPartitionId = Array[Long](110, 10, 100, 110, 0) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // All shuffle partitions are larger than the targeted size. + val bytesByPartitionId = Array[Long](100, 110, 100, 110, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + + { + // The last shuffle partition is in a single coalesced partition. + val bytesByPartitionId = Array[Long](30, 30, 0, 40, 110) + val expectedPartitionStartIndices = Array[Int](0, 4) + checkEstimation(Array(bytesByPartitionId), expectedPartitionStartIndices, targetSize) + } + } + + test("2 shuffles") { + val targetSize = 100 + + { + // If there are multiple values of the number of shuffle partitions, + // we should see an assertion error. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0, 0) + intercept[AssertionError] { + checkEstimation(Array(bytesByPartitionId1, bytesByPartitionId2), Array.empty, targetSize) + } + } + + { + // All bytes per partition are 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // Some bytes per partition are 0. + // 1 coalesced partition is expected. 
+ val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 20, 0, 20) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // 2 coalesced partition are expected. + val bytesByPartitionId1 = Array[Long](0, 10, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 2, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // 4 coalesced partition are expected. + val bytesByPartitionId1 = Array[Long](0, 99, 0, 20, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // 2 coalesced partition are needed. + val bytesByPartitionId1 = Array[Long](0, 100, 0, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 70, 0, 30) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // There are a few large shuffle partitions. + val bytesByPartitionId1 = Array[Long](0, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 0, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + + { + // All pairs of shuffle partitions are larger than the targeted size. + val bytesByPartitionId1 = Array[Long](100, 100, 40, 30, 0) + val bytesByPartitionId2 = Array[Long](30, 0, 60, 70, 110) + val expectedPartitionStartIndices = Array[Int](0, 1, 2, 3, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize) + } + } + + test("enforce minimal number of coalesced partitions") { + val targetSize = 100 + val minNumPartitions = 2 + + { + // The minimal number of coalesced partitions is not enforced because + // the size of data is 0. + val bytesByPartitionId1 = Array[Long](0, 0, 0, 0, 0) + val bytesByPartitionId2 = Array[Long](0, 0, 0, 0, 0) + val expectedPartitionStartIndices = Array[Int](0) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize, minNumPartitions) + } + + { + // The minimal number of coalesced partitions is enforced. + val bytesByPartitionId1 = Array[Long](10, 5, 5, 0, 20) + val bytesByPartitionId2 = Array[Long](5, 10, 0, 10, 5) + val expectedPartitionStartIndices = Array[Int](0, 3) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize, minNumPartitions) + } + + { + // The number of coalesced partitions is determined by the algorithm. 
+ val bytesByPartitionId1 = Array[Long](10, 50, 20, 80, 20) + val bytesByPartitionId2 = Array[Long](40, 10, 0, 10, 30) + val expectedPartitionStartIndices = Array[Int](0, 1, 3, 4) + checkEstimation( + Array(bytesByPartitionId1, bytesByPartitionId2), + expectedPartitionStartIndices, + targetSize, minNumPartitions) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index a2071903bea7e..4edb35ea30fde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -23,7 +23,7 @@ import java.net.URI import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec} import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate import org.apache.spark.sql.internal.SQLConf @@ -594,160 +594,84 @@ class AdaptiveQueryExecSuite .range(0, 1000, 1, 10) .selectExpr("id % 1 as key2", "id as value2") .createOrReplaceTempView("skewData2") - val (innerPlan, innerAdaptivePlan) = runAdaptiveAndVerifyResult( + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( "SELECT key1 FROM skewData1 join skewData2 ON key1 = key2 group by key1") - val innerSmj = findTopLevelSortMergeJoin(innerPlan) - assert(innerSmj.size == 1) // Additional shuffle introduced, so disable the "OptimizeSkewedJoin" optimization - val innerSmjAfter = findTopLevelSortMergeJoin(innerAdaptivePlan) - assert(innerSmjAfter.size == 1) + val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) + assert(innerSmj.size == 1 && !innerSmj.head.isSkewJoin) } } } + // TODO: we need a way to customize data distribution after shuffle, to improve test coverage + // of this case. test("SPARK-29544: adaptive skew join with different join types") { - Seq("false", "true").foreach { reducePostShufflePartitionsEnabled => - withSQLConf( - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", - SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100", - SQLConf.REDUCE_POST_SHUFFLE_PARTITIONS_ENABLED.key -> reducePostShufflePartitionsEnabled, - SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") { - withTempView("skewData1", "skewData2") { - spark - .range(0, 1000, 1, 10) - .selectExpr("id % 2 as key1", "id as value1") - .createOrReplaceTempView("skewData1") - spark - .range(0, 1000, 1, 10) - .selectExpr("id % 1 as key2", "id as value2") - .createOrReplaceTempView("skewData2") - // skewed inner join optimization - val (innerPlan, innerAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 join skewData2 ON key1 = key2") - val innerSmj = findTopLevelSortMergeJoin(innerPlan) - assert(innerSmj.size == 1) - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // the partition 0 in both left and right side are all skewed. 
- // And divide into 5 splits both in left and right (the max splits number). - // So there are 5 x 5 smjs for partition 0. - // Partition 4 in left side is skewed and is divided into 5 splits. - // The right side of partition 4 is not skewed. - // So there are 5 smjs for partition 4. - // So total (25 + 5 + 1) smjs. - // Union - // +- SortMergeJoin - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // . - // . - // . - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - - val innerSmjAfter = findTopLevelSortMergeJoin(innerAdaptivePlan) - assert(innerSmjAfter.size == 31) - - // skewed left outer join optimization - val (leftPlan, leftAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") - val leftSmj = findTopLevelSortMergeJoin(leftPlan) - assert(leftSmj.size == 1) - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // The partition 0 in both left and right are all skewed. - // The partition 4 in left side is skewed. - // But for left outer join, we don't split the right partition even skewed. - // So the partition 0 in left side is divided into 5 splits(the max split number). - // the partition 4 in left side is divided into 5 splits(the max split number). - // So total (5 + 5 + 1) smjs. - // Union - // +- SortMergeJoin - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // . - // . - // . - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - - val leftSmjAfter = findTopLevelSortMergeJoin(leftAdaptivePlan) - assert(leftSmjAfter.size == 11) - - // skewed right outer join optimization - val (rightPlan, rightAdaptivePlan) = runAdaptiveAndVerifyResult( - "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") - val rightSmj = findTopLevelSortMergeJoin(rightPlan) - assert(rightSmj.size == 1) - // left stats: [3496, 0, 0, 0, 4014] - // right stats:[6292, 0, 0, 0, 0] - // The partition 0 in both left and right side are all skewed. - // And the partition 4 in left side is skewed. - // But for right outer join, we don't split the left partition even skewed. - // And divide right side into 5 splits(the max split number) - // So total 6 smjs. - // Union - // +- SortMergeJoin - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- CoalescedShuffleReader - // +- ShuffleQueryStage - // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // . - // . - // . 
- // +- SortMergeJoin - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - // +- Sort - // +- SkewedShuffleReader - // +- ShuffleQueryStage - - val rightSmjAfter = findTopLevelSortMergeJoin(rightAdaptivePlan) - assert(rightSmjAfter.size == 6) + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_EXECUTION_SKEWED_PARTITION_SIZE_THRESHOLD.key -> "100", + SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "700") { + withTempView("skewData1", "skewData2") { + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 2 as key1", "id as value1") + .createOrReplaceTempView("skewData1") + spark + .range(0, 1000, 1, 10) + .selectExpr("id % 1 as key2", "id as value2") + .createOrReplaceTempView("skewData2") + + def checkSkewJoin(joins: Seq[SortMergeJoinExec], expectedNumPartitions: Int): Unit = { + assert(joins.size == 1 && joins.head.isSkewJoin) + assert(joins.head.left.collect { + case r: SkewJoinShuffleReaderExec => r + }.head.partitionSpecs.length == expectedNumPartitions) + assert(joins.head.right.collect { + case r: SkewJoinShuffleReaderExec => r + }.head.partitionSpecs.length == expectedNumPartitions) } + + // skewed inner join optimization + val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 join skewData2 ON key1 = key2") + // left stats: [3496, 0, 0, 0, 4014] + // right stats:[6292, 0, 0, 0, 0] + // Partition 0: both left and right sides are skewed, and divide into 5 splits, so + // 5 x 5 sub-partitions. + // Partition 1, 2, 3: not skewed, and coalesced into 1 partition. + // Partition 4: only left side is skewed, and divide into 5 splits, so + // 5 sub-partitions. + // So total (25 + 1 + 5) partitions. + val innerSmj = findTopLevelSortMergeJoin(innerAdaptivePlan) + checkSkewJoin(innerSmj, 25 + 1 + 5) + + // skewed left outer join optimization + val (_, leftAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 left outer join skewData2 ON key1 = key2") + // left stats: [3496, 0, 0, 0, 4014] + // right stats:[6292, 0, 0, 0, 0] + // Partition 0: both left and right sides are skewed, but left join can't split right side, + // so only left side is divided into 5 splits, and thus 5 sub-partitions. + // Partition 1, 2, 3: not skewed, and coalesced into 1 partition. + // Partition 4: only left side is skewed, and divide into 5 splits, so + // 5 sub-partitions. + // So total (5 + 1 + 5) partitions. + val leftSmj = findTopLevelSortMergeJoin(leftAdaptivePlan) + checkSkewJoin(leftSmj, 5 + 1 + 5) + + // skewed right outer join optimization + val (_, rightAdaptivePlan) = runAdaptiveAndVerifyResult( + "SELECT * FROM skewData1 right outer join skewData2 ON key1 = key2") + // left stats: [3496, 0, 0, 0, 4014] + // right stats:[6292, 0, 0, 0, 0] + // Partition 0: both left and right sides are skewed, but right join can't split left side, + // so only right side is divided into 5 splits, and thus 5 sub-partitions. + // Partition 1, 2, 3: not skewed, and coalesced into 1 partition. + // Partition 4: only left side is skewed, but right join can't split left side, so just + // 1 partition. + // So total (5 + 1 + 1) partitions. 
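The expected spec counts passed to `checkSkewJoin` follow from a simple per-partition product rule, sketched here as REPL-pasteable Scala. The grouping of partitions 1-3 into one coalesced spec and the split count of 5 are taken from the comments above; `expectedSpecs` is an illustrative helper, not part of the test suite.

```
// Per reducer partition, each side contributes the number of map-range splits it was divided
// into (1 if it stays whole), and each reader ends up with the product of the two as specs.
// The three partition groups are: partition 0, partitions 1-3 (coalesced), and partition 4.
def expectedSpecs(leftSplits: Seq[Int], rightSplits: Seq[Int]): Int =
  leftSplits.zip(rightSplits).map { case (l, r) => l * r }.sum

assert(expectedSpecs(Seq(5, 1, 5), Seq(5, 1, 1)) == 25 + 1 + 5)  // inner join
assert(expectedSpecs(Seq(5, 1, 5), Seq(1, 1, 1)) == 5 + 1 + 5)   // left outer: right side not split
assert(expectedSpecs(Seq(1, 1, 1), Seq(5, 1, 1)) == 5 + 1 + 1)   // right outer: left side not split
```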
+ val rightSmj = findTopLevelSortMergeJoin(rightAdaptivePlan) + checkSkewJoin(rightSmj, 5 + 1 + 1) } } } @@ -805,3 +729,4 @@ class AdaptiveQueryExecSuite s" enabled but is not supported for"))) } } + From 859699135cb63b57f5d844e762070065cedb4408 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Feb 2020 11:17:27 -0800 Subject: [PATCH 038/185] [SPARK-30807][K8S][TESTS] Support Java 11 in K8S integration tests ### What changes were proposed in this pull request? This PR aims to support JDK11 test in K8S integration tests. - This is an update in testing framework instead of individual tests. - This will enable JDK11 runtime test when you didn't installed JDK11 on your local system. ### Why are the changes needed? Apache Spark 3.0.0 adds JDK11 support, but K8s integration tests use JDK8 until now. ### Does this PR introduce any user-facing change? No. This is a dev-only test-related PR. ### How was this patch tested? This is irrelevant to Jenkins UT, but Jenkins K8S IT (JDK8) should pass. - https://github.com/apache/spark/pull/27559#issuecomment-585903489 (JDK8 Passed) And, manually do the following for JDK11 test. ``` $ NO_MANUAL=1 ./dev/make-distribution.sh --r --pip --tgz -Phadoop-3.2 -Pkubernetes $ resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh --java-image-tag 11-jre-slim --spark-tgz $PWD/spark-*.tgz ``` ``` $ docker run -it --rm kubespark/spark:1318DD8A-2B15-4A00-BC69-D0E90CED235B /usr/local/openjdk-11/bin/java --version | tail -n1 OpenJDK 64-Bit Server VM 18.9 (build 11.0.6+10, mixed mode) ``` Closes #27559 from dongjoon-hyun/SPARK-30807. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../docker/src/main/dockerfiles/spark/Dockerfile | 3 ++- .../kubernetes/integration-tests/README.md | 15 +++++++++++++-- .../dev/dev-run-integration-tests.sh | 10 ++++++++++ .../kubernetes/integration-tests/pom.xml | 4 ++++ .../scripts/setup-integration-test-env.sh | 14 +++++++++++--- 5 files changed, 40 insertions(+), 6 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index a1fc63789bc61..6ed37fc637b31 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # +ARG java_image_tag=8-jre-slim -FROM openjdk:8-jre-slim +FROM openjdk:${java_image_tag} ARG spark_uid=185 diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index d7ad35a175a61..18b91916208d6 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -6,13 +6,17 @@ title: Spark on Kubernetes Integration Tests # Running the Kubernetes Integration Tests Note that the integration test framework is currently being heavily revised and -is subject to change. Note that currently the integration tests only run with Java 8. +is subject to change. The simplest way to run the integration tests is to install and run Minikube, then run the following from this directory: ./dev/dev-run-integration-tests.sh +To run tests with Java 11 instead of Java 8, use `--java-image-tag` to specify the base image. 
+ + ./dev/dev-run-integration-tests.sh --java-image-tag 11-jre-slim + The minimum tested version of Minikube is 0.23.0. The kube-dns addon must be enabled. Minikube should run with a minimum of 4 CPUs and 6G of memory: @@ -183,7 +187,14 @@ to the wrapper scripts and using the wrapper scripts will simply set these appro A specific image tag to use, when set assumes images with those tags are already built and available in the specified image repository. When set to N/A (the default) fresh images will be built. - N/A + N/A + + + spark.kubernetes.test.javaImageTag + + A specific OpenJDK base image tag to use, when set uses it instead of 8-jre-slim. + + 8-jre-slim spark.kubernetes.test.imageTagFile diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 1f0a8035cea7b..76d6e1c1e8499 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -23,6 +23,7 @@ DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" SPARK_TGZ="N/A" IMAGE_TAG="N/A" +JAVA_IMAGE_TAG= BASE_IMAGE_NAME= JVM_IMAGE_NAME= PYTHON_IMAGE_NAME= @@ -52,6 +53,10 @@ while (( "$#" )); do IMAGE_TAG="$2" shift ;; + --java-image-tag) + JAVA_IMAGE_TAG="$2" + shift + ;; --deploy-mode) DEPLOY_MODE="$2" shift @@ -120,6 +125,11 @@ properties=( -Dtest.include.tags=$INCLUDE_TAGS ) +if [ -n "$JAVA_IMAGE_TAG" ]; +then + properties=( ${properties[@]} -Dspark.kubernetes.test.javaImageTag=$JAVA_IMAGE_TAG ) +fi + if [ -n $NAMESPACE ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 8e1043f77db6d..369dfd491826c 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -39,6 +39,7 @@ ${project.build.directory}/spark-dist-unpacked N/A + 8-jre-slim ${project.build.directory}/imageTag.txt minikube docker.io/kubespark @@ -109,6 +110,9 @@ --image-tag ${spark.kubernetes.test.imageTag} + --java-image-tag + ${spark.kubernetes.test.javaImageTag} + --image-tag-output-file ${spark.kubernetes.test.imageTagFile} diff --git a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh index 9e04b963fc40e..ab906604fce06 100755 --- a/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh +++ b/resource-managers/kubernetes/integration-tests/scripts/setup-integration-test-env.sh @@ -23,6 +23,7 @@ IMAGE_TAG_OUTPUT_FILE="$TEST_ROOT_DIR/target/image-tag.txt" DEPLOY_MODE="minikube" IMAGE_REPO="docker.io/kubespark" IMAGE_TAG="N/A" +JAVA_IMAGE_TAG="8-jre-slim" SPARK_TGZ="N/A" # Parse arguments @@ -40,6 +41,10 @@ while (( "$#" )); do IMAGE_TAG="$2" shift ;; + --java-image-tag) + JAVA_IMAGE_TAG="$2" + shift + ;; --image-tag-output-file) IMAGE_TAG_OUTPUT_FILE="$2" shift @@ -82,6 +87,9 @@ then IMAGE_TAG=$(uuidgen); cd $SPARK_INPUT_DIR + # OpenJDK base-image tag (e.g. 
8-jre-slim, 11-jre-slim) + JAVA_IMAGE_TAG_BUILD_ARG="-b java_image_tag=$JAVA_IMAGE_TAG" + # Build PySpark image LANGUAGE_BINDING_BUILD_ARGS="-p $DOCKER_FILE_BASE_PATH/bindings/python/Dockerfile" @@ -95,7 +103,7 @@ then case $DEPLOY_MODE in cloud) # Build images - $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build # Push images appropriately if [[ $IMAGE_REPO == gcr.io* ]] ; @@ -109,13 +117,13 @@ then docker-for-desktop) # Only need to build as this will place it in our local Docker repo which is all # we need for Docker for Desktop to work so no need to also push - $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build ;; minikube) # Only need to build and if we do this with the -m option for minikube we will # build the images directly using the minikube Docker daemon so no need to push - $SPARK_INPUT_DIR/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $LANGUAGE_BINDING_BUILD_ARGS build + $SPARK_INPUT_DIR/bin/docker-image-tool.sh -m -r $IMAGE_REPO -t $IMAGE_TAG $JAVA_IMAGE_TAG_BUILD_ARG $LANGUAGE_BINDING_BUILD_ARGS build ;; *) echo "Unrecognized deploy mode $DEPLOY_MODE" && exit 1 From 74cd46eb691be5dc1cb1c496eeeaa2614945bd98 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 13 Feb 2020 11:42:00 -0800 Subject: [PATCH 039/185] [SPARK-30816][K8S][TESTS] Fix dev-run-integration-tests.sh to ignore empty params ### What changes were proposed in this pull request? This PR aims to fix `dev-run-integration-tests.sh` to ignore empty params correctly. ### Why are the changes needed? The following script runs `mvn` integration test like the following. ``` $ resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh ... build/mvn integration-test -f /Users/dongjoon/APACHE/spark/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-2.12 -Pkubernetes -Pkubernetes-integration-tests -Djava.version=8 -Dspark.kubernetes.test.sparkTgz=N/A -Dspark.kubernetes.test.imageTag=N/A -Dspark.kubernetes.test.imageRepo=docker.io/kubespark -Dspark.kubernetes.test.deployMode=minikube -Dtest.include.tags=k8s -Dspark.kubernetes.test.namespace= -Dspark.kubernetes.test.serviceAccountName= -Dspark.kubernetes.test.kubeConfigContext= -Dspark.kubernetes.test.master= -Dtest.exclude.tags= -Dspark.kubernetes.test.jvmImage=spark -Dspark.kubernetes.test.pythonImage=spark-py -Dspark.kubernetes.test.rImage=spark-r ``` After this PR, the empty parameters like the followings will be skipped like the original design. ``` -Dspark.kubernetes.test.namespace= -Dspark.kubernetes.test.serviceAccountName= -Dspark.kubernetes.test.kubeConfigContext= -Dspark.kubernetes.test.master= -Dtest.exclude.tags= ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Pass the Jenkins K8S integration test. Closes #27566 from dongjoon-hyun/SPARK-30816. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../integration-tests/dev/dev-run-integration-tests.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 76d6e1c1e8499..607bb243458a6 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -130,27 +130,27 @@ then properties=( ${properties[@]} -Dspark.kubernetes.test.javaImageTag=$JAVA_IMAGE_TAG ) fi -if [ -n $NAMESPACE ]; +if [ -n "$NAMESPACE" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.namespace=$NAMESPACE ) fi -if [ -n $SERVICE_ACCOUNT ]; +if [ -n "$SERVICE_ACCOUNT" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.serviceAccountName=$SERVICE_ACCOUNT ) fi -if [ -n $CONTEXT ]; +if [ -n "$CONTEXT" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.kubeConfigContext=$CONTEXT ) fi -if [ -n $SPARK_MASTER ]; +if [ -n "$SPARK_MASTER" ]; then properties=( ${properties[@]} -Dspark.kubernetes.test.master=$SPARK_MASTER ) fi -if [ -n $EXCLUDE_TAGS ]; +if [ -n "$EXCLUDE_TAGS" ]; then properties=( ${properties[@]} -Dtest.exclude.tags=$EXCLUDE_TAGS ) fi From e2d3983de78f5c80fac066b7ee8bedd0987110dd Mon Sep 17 00:00:00 2001 From: Ali Afroozeh Date: Thu, 13 Feb 2020 23:58:55 +0100 Subject: [PATCH 040/185] [SPARK-30798][SQL] Scope Session.active in QueryExecution ### What changes were proposed in this pull request? This PR scopes `SparkSession.active` to prevent problems with processing queries with possibly different spark sessions (and different configs). A new method, `withActive` is introduced on `SparkSession` that restores the previous spark session after the block of code is executed. ### Why are the changes needed? `SparkSession.active` is a thread local variable that points to the current thread's spark session. It is important to note that the `SQLConf.get` method depends on `SparkSession.active`. In the current implementation it is possible that `SparkSession.active` points to a different session which causes various problems. Most of these problems arise because part of the query processing is done using the configurations of a different session. For example, when creating a data frame using a new session, i.e., `session.sql("...")`, part of the data frame is constructed using the currently active spark session, which can be a different session from the one used later for processing the query. ### Does this PR introduce any user-facing change? The `withActive` method is introduced on `SparkSession`. ### How was this patch tested? Unit tests (to be added) Closes #27387 from dbaliafroozeh/UseWithActiveSessionInQueryExecution. 
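The `withActive` behavior described above is essentially a scope guard over a thread-local value: remember whatever was active before, install the new session for the duration of the block, and restore the previous value in a `finally`. A minimal, self-contained sketch of that shape follows; the names (`DemoSession`, `ActiveScopeDemo`) are illustrative stand-ins, not Spark's API.

```scala
object ActiveScopeDemo {
  final case class DemoSession(name: String)

  // Thread-local holder standing in for the "currently active" session.
  private val active = new ThreadLocal[DemoSession]

  // Save whatever was active, install `session` for the duration of `block`,
  // and restore the previous value afterwards, even if `block` throws.
  def withActive[T](session: DemoSession)(block: => T): T = {
    val old = active.get() // may be null if nothing was active on this thread
    active.set(session)
    try block finally active.set(old)
  }

  def main(args: Array[String]): Unit = {
    active.set(DemoSession("outer"))
    withActive(DemoSession("inner")) {
      println(active.get()) // DemoSession(inner) while the block runs
    }
    println(active.get()) // DemoSession(outer) is restored afterwards
  }
}
```

Reading the old value straight from the thread local, rather than falling back to any default, is what keeps a default session from being promoted to the active one once the block completes, which is the point the patch makes in its inline comment.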
Authored-by: Ali Afroozeh Signed-off-by: herman --- .../apache/spark/sql/DataFrameWriter.scala | 2 +- .../apache/spark/sql/DataFrameWriterV2.scala | 2 +- .../scala/org/apache/spark/sql/Dataset.scala | 36 ++++++++++--------- .../spark/sql/KeyValueGroupedDataset.scala | 5 +-- .../org/apache/spark/sql/SparkSession.scala | 30 +++++++++++----- .../spark/sql/execution/QueryExecution.scala | 16 +++++---- .../spark/sql/execution/SQLExecution.scala | 4 +-- .../streaming/MicroBatchExecution.scala | 4 +-- .../continuous/ContinuousExecution.scala | 2 +- .../spark/sql/internal/CatalogImpl.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 10 ++++++ .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../ui/SQLAppStatusListenerSuite.scala | 2 +- .../SparkExecuteStatementOperation.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 2 +- .../hive/execution/HiveComparisonTest.scala | 3 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 17 files changed, 74 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 4557219abeb15..fff1f4b636dea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -896,7 +896,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { private def runCommand(session: SparkSession, name: String)(command: LogicalPlan): Unit = { val qe = session.sessionState.executePlan(command) // call `QueryExecution.toRDD` to trigger the execution of commands. - SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd) + SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd) } private def lookupV2Provider(): Option[TableProvider] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index f5dd7613d4103..cf6bde5a2bcb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -226,7 +226,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) private def runCommand(name: String)(command: LogicalPlan): Unit = { val qe = sparkSession.sessionState.executePlan(command) // call `QueryExecution.toRDD` to trigger the execution of commands. 
- SQLExecution.withNewExecutionId(sparkSession, qe, Some(name))(qe.toRdd) + SQLExecution.withNewExecutionId(qe, Some(name))(qe.toRdd) } private def internalReplace(orCreate: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a1c33f92d17b4..42f35354e864f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -82,18 +82,19 @@ private[sql] object Dataset { dataset } - def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = { - val qe = sparkSession.sessionState.executePlan(logicalPlan) - qe.assertAnalyzed() - new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = + sparkSession.withActive { + val qe = sparkSession.sessionState.executePlan(logicalPlan) + qe.assertAnalyzed() + new Dataset[Row](qe, RowEncoder(qe.analyzed.schema)) } /** A variant of ofRows that allows passing in a tracker so we can track query parsing time. */ def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan, tracker: QueryPlanningTracker) - : DataFrame = { + : DataFrame = sparkSession.withActive { val qe = new QueryExecution(sparkSession, logicalPlan, tracker) qe.assertAnalyzed() - new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema)) + new Dataset[Row](qe, RowEncoder(qe.analyzed.schema)) } } @@ -185,13 +186,12 @@ private[sql] object Dataset { */ @Stable class Dataset[T] private[sql]( - @transient private val _sparkSession: SparkSession, @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) extends Serializable { @transient lazy val sparkSession: SparkSession = { - if (_sparkSession == null) { + if (queryExecution == null || queryExecution.sparkSession == null) { throw new SparkException( "Dataset transformations and actions can only be invoked by the driver, not inside of" + " other Dataset transformations; for example, dataset1.map(x => dataset2.values.count()" + @@ -199,7 +199,7 @@ class Dataset[T] private[sql]( "performed inside of the dataset1.map transformation. For more information," + " see SPARK-28702.") } - _sparkSession + queryExecution.sparkSession } // A globally unique id of this Dataset. @@ -211,7 +211,7 @@ class Dataset[T] private[sql]( // you wrap it with `withNewExecutionId` if this actions doesn't call other action. def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { - this(sparkSession, sparkSession.sessionState.executePlan(logicalPlan), encoder) + this(sparkSession.sessionState.executePlan(logicalPlan), encoder) } def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = { @@ -445,7 +445,7 @@ class Dataset[T] private[sql]( */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DataFrame = new Dataset[Row](sparkSession, queryExecution, RowEncoder(schema)) + def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema)) /** * Returns a new Dataset where each record has been mapped on to the specified type. 
The @@ -503,7 +503,9 @@ class Dataset[T] private[sql]( * @group basic * @since 1.6.0 */ - def schema: StructType = queryExecution.analyzed.schema + def schema: StructType = sparkSession.withActive { + queryExecution.analyzed.schema + } /** * Prints the schema to the console in a nice tree format. @@ -539,7 +541,7 @@ class Dataset[T] private[sql]( * @group basic * @since 3.0.0 */ - def explain(mode: String): Unit = { + def explain(mode: String): Unit = sparkSession.withActive { // Because temporary views are resolved during analysis when we create a Dataset, and // `ExplainCommand` analyzes input query plan and resolves temporary views again. Using // `ExplainCommand` here will probably output different query plans, compared to the results @@ -1502,7 +1504,7 @@ class Dataset[T] private[sql]( val namedColumns = columns.map(_.withInputType(exprEnc, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) - new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders)) + new Dataset(execution, ExpressionEncoder.tuple(encoders)) } /** @@ -3472,7 +3474,7 @@ class Dataset[T] private[sql]( * an execution. */ private def withNewExecutionId[U](body: => U): U = { - SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) + SQLExecution.withNewExecutionId(queryExecution)(body) } /** @@ -3481,7 +3483,7 @@ class Dataset[T] private[sql]( * reset. */ private def withNewRDDExecutionId[U](body: => U): U = { - SQLExecution.withNewExecutionId(sparkSession, rddQueryExecution) { + SQLExecution.withNewExecutionId(rddQueryExecution) { rddQueryExecution.executedPlan.resetMetrics() body } @@ -3492,7 +3494,7 @@ class Dataset[T] private[sql]( * user-registered callback functions. */ private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = { - SQLExecution.withNewExecutionId(sparkSession, qe, Some(name)) { + SQLExecution.withNewExecutionId(qe, Some(name)) { qe.executedPlan.resetMetrics() action(qe.executedPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 89cc9735e4f6a..76ee297dfca79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -449,10 +449,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) - new Dataset( - sparkSession, - execution, - ExpressionEncoder.tuple(kExprEnc +: encoders)) + new Dataset(execution, ExpressionEncoder.tuple(kExprEnc +: encoders)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index abefb348cafc7..1fb97fb4b4cf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -293,8 +293,7 @@ class SparkSession private( * * @since 2.0.0 */ - def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SparkSession.setActiveSession(this) + def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = withActive { val encoder = Encoders.product[A] Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder)) } @@ -304,8 +303,7 @@ class SparkSession private( * * @since 2.0.0 */ - def createDataFrame[A <: 
Product : TypeTag](data: Seq[A]): DataFrame = { - SparkSession.setActiveSession(this) + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = withActive { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] val attributeSeq = schema.toAttributes Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) @@ -343,7 +341,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive { // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val encoder = RowEncoder(schema) @@ -373,7 +371,7 @@ class SparkSession private( * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive { Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) } @@ -385,7 +383,7 @@ class SparkSession private( * * @since 2.0.0 */ - def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = withActive { val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -414,7 +412,7 @@ class SparkSession private( * SELECT * queries will return the columns in an undefined order. * @since 1.6.0 */ - def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { + def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = withActive { val attrSeq = getSchema(beanClass) val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq) Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) @@ -599,7 +597,7 @@ class SparkSession private( * * @since 2.0.0 */ - def sql(sqlText: String): DataFrame = { + def sql(sqlText: String): DataFrame = withActive { val tracker = new QueryPlanningTracker val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) { sessionState.sqlParser.parsePlan(sqlText) @@ -751,6 +749,20 @@ class SparkSession private( } } + /** + * Execute a block of code with the this session set as the active session, and restore the + * previous session on completion. + */ + private[sql] def withActive[T](block: => T): T = { + // Use the active session thread local directly to make sure we get the session that is actually + // set and not the default session. This to prevent that we promote the default session to the + // active session once we are done. 
+ val old = SparkSession.activeThreadSession.get() + SparkSession.setActiveSession(this) + try block finally { + SparkSession.setActiveSession(old) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 38ef66682c413..53b6b5d82c021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -63,13 +63,12 @@ class QueryExecution( } } - lazy val analyzed: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.ANALYSIS) { - SparkSession.setActiveSession(sparkSession) + lazy val analyzed: LogicalPlan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) } - lazy val withCachedData: LogicalPlan = { + lazy val withCachedData: LogicalPlan = sparkSession.withActive { assertAnalyzed() assertSupported() // clone the plan to avoid sharing the plan instance between different stages like analyzing, @@ -77,20 +76,20 @@ class QueryExecution( sparkSession.sharedState.cacheManager.useCachedData(analyzed.clone()) } - lazy val optimizedPlan: LogicalPlan = tracker.measurePhase(QueryPlanningTracker.OPTIMIZATION) { + lazy val optimizedPlan: LogicalPlan = executePhase(QueryPlanningTracker.OPTIMIZATION) { // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. sparkSession.sessionState.optimizer.executeAndTrack(withCachedData.clone(), tracker) } - lazy val sparkPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { + lazy val sparkPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) { // Clone the logical plan here, in case the planner rules change the states of the logical plan. QueryExecution.createSparkPlan(sparkSession, planner, optimizedPlan.clone()) } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = tracker.measurePhase(QueryPlanningTracker.PLANNING) { + lazy val executedPlan: SparkPlan = executePhase(QueryPlanningTracker.PLANNING) { // clone the plan to avoid sharing the plan instance between different stages like analyzing, // optimizing and planning. QueryExecution.prepareForExecution(preparations, sparkPlan.clone()) @@ -116,6 +115,10 @@ class QueryExecution( QueryExecution.preparations(sparkSession) } + private def executePhase[T](phase: String)(block: => T): T = sparkSession.withActive { + tracker.measurePhase(phase)(block) + } + def simpleString: String = simpleString(false) def simpleString(formatted: Boolean): String = withRedaction { @@ -305,7 +308,6 @@ object QueryExecution { sparkSession: SparkSession, planner: SparkPlanner, plan: LogicalPlan): SparkPlan = { - SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. 
planner.plan(ReturnAnswer(plan)).next() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 9f177819f6ea7..59c503e372535 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -62,9 +62,9 @@ object SQLExecution { * we can connect them with an execution. */ def withNewExecutionId[T]( - sparkSession: SparkSession, queryExecution: QueryExecution, - name: Option[String] = None)(body: => T): T = { + name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive { + val sparkSession = queryExecution.sparkSession val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY) val executionId = SQLExecution.nextExecutionId diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 83bc347e23ed4..45a2ce16183a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -563,11 +563,11 @@ class MicroBatchExecution( } val nextBatch = - new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema)) + new Dataset(lastExecution, RowEncoder(lastExecution.analyzed.schema)) val batchSinkProgress: Option[StreamWriterCommitProgress] = reportTimeTaken("addBatch") { - SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { + SQLExecution.withNewExecutionId(lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) case _: SupportsWrite => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a9b724a73a18e..a109c2171f3d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -252,7 +252,7 @@ class ContinuousExecution( updateStatusMessage("Running") reportTimeTaken("runContinuous") { - SQLExecution.withNewExecutionId(sparkSessionForQuery, lastExecution) { + SQLExecution.withNewExecutionId(lastExecution) { lastExecution.executedPlan.execute() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 3740b56cb9cbb..d3ef03e9b3b74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -520,7 +520,7 @@ private[sql] object CatalogImpl { val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(enc.schema.toAttributes, encoded) val queryExecution = sparkSession.sessionState.executePlan(plan) - new Dataset[T](sparkSession, queryExecution, enc) + new Dataset[T](queryExecution, enc) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 233d67898f909..b0bd612e88d98 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1899,6 +1899,16 @@ class DatasetSuite extends QueryTest val e = intercept[AnalysisException](spark.range(1).tail(-1)) e.getMessage.contains("tail expression must be equal to or greater than 0") } + + test("SparkSession.active should be the same instance after dataset operations") { + val active = SparkSession.getActiveSession.get + val clone = active.cloneSession() + val ds = new Dataset(clone, spark.range(10).queryExecution.logical, Encoders.INT) + + ds.queryExecution.analyzed + + assert(active eq SparkSession.getActiveSession.get) + } } object AssertExecutionId { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index da4727f6a98cb..83285911b3948 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -511,7 +511,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession { val df = session.sql(sql) val schema = df.schema.catalogString // Get answer, but also get rid of the #1234 expression ids that show up in explain plans - val answer = SQLExecution.withNewExecutionId(session, df.queryExecution, Some(sql)) { + val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) { hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index fdfd392a224cb..d18a35c3110f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -508,7 +508,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils override lazy val executedPlan = physicalPlan } - SQLExecution.withNewExecutionId(spark, dummyQueryExecution) { + SQLExecution.withNewExecutionId(dummyQueryExecution) { physicalPlan.execute().collect() } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 76d07848f79a9..cf0e5ebf3a2b1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -295,7 +295,7 @@ private[hive] class SparkExecuteStatementOperation( resultList.get.iterator } } - dataTypes = result.queryExecution.analyzed.output.map(_.dataType).toArray + dataTypes = result.schema.fields.map(_.dataType) } catch { // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. 
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index 362ac362e9718..12fba0eae6dce 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlCont try { context.sparkContext.setJobDescription(command) val execution = context.sessionState.executePlan(context.sql(command).logicalPlan) - hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) { + hiveResponse = SQLExecution.withNewExecutionId(execution) { hiveResultString(execution.executedPlan) } tableSchema = getResultSetSchema(execution) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 28e1db961f611..8b1f4c92755b9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -346,8 +346,7 @@ abstract class HiveComparisonTest val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath)) def getResult(): Seq[String] = { - SQLExecution.withNewExecutionId( - query.sparkSession, query)(hiveResultString(query.executedPlan)) + SQLExecution.withNewExecutionId(query)(hiveResultString(query.executedPlan)) } try { (query, prepareAnswer(query, getResult())) } catch { case e: Throwable => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index cc4592a5caf68..222244a04f5f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -501,7 +501,7 @@ private[hive] class TestHiveSparkSession( // has already set the execution id. if (sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) == null) { // We don't actually have a `QueryExecution` here, use a fake one instead. - SQLExecution.withNewExecutionId(this, new QueryExecution(this, OneRowRelation())) { + SQLExecution.withNewExecutionId(new QueryExecution(this, OneRowRelation())) { createCmds.foreach(_()) } } else { From 57254c9719f9af9ad985596ed7fbbaafa4052002 Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Thu, 13 Feb 2020 16:15:00 -0800 Subject: [PATCH 041/185] [SPARK-30667][CORE] Add allGather method to BarrierTaskContext ### What changes were proposed in this pull request? The `allGather` method is added to the `BarrierTaskContext`. This method contains the same functionality as the `BarrierTaskContext.barrier` method; it blocks the task until all tasks make the call, at which time they may continue execution. In addition, the `allGather` method takes an input message. Upon returning from the `allGather` the task receives a list of all the messages sent by all the tasks that made the `allGather` call. ### Why are the changes needed? There are many situations where having the tasks communicate in a synchronized way is useful. 
One simple example is if each task needs to start a server to serve requests from one another; first the tasks must find a free port (the result of which is undetermined beforehand) and then start making requests, but to do so they each must know the port chosen by the other task. An `allGather` method would allow them to inform each other of the port they will run on. ### Does this PR introduce any user-facing change? Yes, an `BarrierTaskContext.allGather` method will be available through the Scala, Java, and Python APIs. ### How was this patch tested? Most of the code path is already covered by tests to the `barrier` method, since this PR includes a refactor so that much code is shared by the `barrier` and `allGather` methods. However, a test is added to assert that an all gather on each tasks partition ID will return a list of every partition ID. An example through the Python API: ```python >>> from pyspark import BarrierTaskContext >>> >>> def f(iterator): ... context = BarrierTaskContext.get() ... return [context.allGather('{}'.format(context.partitionId()))] ... >>> sc.parallelize(range(4), 4).barrier().mapPartitions(f).collect()[0] [u'3', u'1', u'0', u'2'] ``` Closes #27395 from sarthfrey/master. Lead-authored-by: sarthfrey-db Co-authored-by: sarthfrey Signed-off-by: Xiangrui Meng --- .../org/apache/spark/BarrierCoordinator.scala | 113 +++++++++++-- .../org/apache/spark/BarrierTaskContext.scala | 153 ++++++++++++------ .../spark/api/python/PythonRunner.scala | 51 ++++-- .../scheduler/BarrierTaskContextSuite.scala | 74 +++++++++ python/pyspark/taskcontext.py | 49 +++++- python/pyspark/tests/test_taskcontext.py | 20 +++ 6 files changed, 381 insertions(+), 79 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 4e417679ca663..042a2664a0e27 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,12 +17,17 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator( // reset when a barrier() call fails due to timeout. private var barrierEpoch: Int = 0 - // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() - // call. + // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call + private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] + + // The blocking requestMethod called by tasks to sync up for this stage attempt + private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER + // A timer task that ensures we may timeout for a barrier() call. 
private var timerTask: TimerTask = null @@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + def handleRequest( + requester: RpcCallContext, + request: RequestToSync + ): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch + val requestMethod = request.requestMethod + val partitionId = request.partitionId + val allGatherMessage = request match { + case ag: AllGatherRequestToSync => ag.allGatherMessage + case _ => "" + } + + if (requesters.size == 0) { + requestMethodToSync = requestMethod + } + + if (requestMethodToSync != requestMethod) { + requesters.foreach( + _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + + s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + + s"the current synchronized requestMethod `$requestMethodToSync`" + )) + ) + cleanupBarrierStage(barrierId) + } // Require the number of tasks is correctly set from the BarrierTaskContext. require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + @@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester + allGatherMessages(partitionId) = allGatherMessage logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (maybeFinishAllRequesters(requesters, numTasks)) { @@ -162,6 +196,7 @@ private[spark] class BarrierCoordinator( s"tasks, finished successfully.") barrierEpoch += 1 requesters.clear() + allGatherMessages.clear() cancelTimerTask() } } @@ -173,7 +208,13 @@ private[spark] class BarrierCoordinator( requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { if (requesters.size == numTasks) { - requesters.foreach(_.reply(())) + requestMethodToSync match { + case RequestMethod.BARRIER => + requesters.foreach(_.reply("")) + case RequestMethod.ALL_GATHER => + val json: String = compact(render(allGatherMessages)) + requesters.foreach(_.reply(json)) + } true } else { false @@ -186,6 +227,7 @@ private[spark] class BarrierCoordinator( // messages come from current stage attempt shall fail. barrierEpoch = -1 requesters.clear() + allGatherMessages.clear() cancelTimerTask() } } @@ -199,11 +241,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + case request: RequestToSync => // Get or init the ContextBarrierState correspond to the stage attempt. 
- val barrierId = ContextBarrierId(stageId, stageAttemptId) + val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -216,6 +258,16 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { + def numTasks: Int + def stageId: Int + def stageAttemptId: Int + def taskAttemptId: Long + def barrierEpoch: Int + def partitionId: Int + def requestMethod: RequestMethod.Value +} + /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is * identified by stageId + stageAttemptId + barrierEpoch. @@ -224,11 +276,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator */ -private[spark] case class RequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int) extends BarrierCoordinatorMessage +private[spark] case class BarrierRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value +) extends RequestToSync + +/** + * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. 
+ * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator + * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER + */ +private[spark] case class AllGatherRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String +) extends RequestToSync + +private[spark] object RequestMethod extends Enumeration { + val BARRIER, ALL_GATHER = Value +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 3d369802f3023..2263538a11676 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,11 +17,19 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.json4s.DefaultFormats +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - /** - * :: Experimental :: - * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to - * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same - * stage have reached this routine. - * - * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all - * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of '''misuses''' are listed below: - * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it - * shall lead to timeout of the function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * if (context.partitionId() == 0) { - * // Do nothing. - * } else { - * context.barrier() - * } - * iter - * } - * }}} - * - * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the - * second function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * try { - * // Do something that might throw an Exception. 
- * doSomething() - * context.barrier() - * } catch { - * case e: Exception => logWarning("...", e) - * } - * context.barrier() - * iter - * } - * }}} - */ - @Experimental - @Since("2.4.0") - def barrier(): Unit = { + private def getRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptNumber: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String + ): RequestToSync = { + requestMethod match { + case RequestMethod.BARRIER => + BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod) + case RequestMethod.ALL_GATHER => + AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod, allGatherMessage) + } + } + + private def runBarrier( + requestMethod: RequestMethod.Value, + allGatherMessage: String = "" + ): String = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. timer.schedule(timerTask, 60000, 60000) + var json: String = "" + try { - val abortableRpcFuture = barrierCoordinator.askAbortable[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch), + val abortableRpcFuture = barrierCoordinator.askAbortable[String]( + message = getRequestToSync(numTasks, stageId, stageAttemptNumber, + taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) @@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } + json + } + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. 
+ * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { + runBarrier(RequestMethod.BARRIER) + () + } + + /** + * :: Experimental :: + * Blocks until all tasks in the same stage have reached this routine. Each task passes in + * a message and returns with a list of all the messages passed in by each of those tasks. + * + * CAUTION! The allGather method requires the same precautions as the barrier method + * + * The message is type String rather than Array[Byte] because it is more convenient for + * the user at the cost of worse performance. + */ + @Experimental + @Since("3.0.0") + def allGather(message: String): ArrayBuffer[String] = { + val json = runBarrier(RequestMethod.ALL_GATHER, message) + val jsonArray = parse(json) + implicit val formats = DefaultFormats + ArrayBuffer(jsonArray.extract[Array[String]]: _*) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 658e0d593a167..fa8bf0fc06358 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,8 +24,13 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -238,13 +243,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock.setSoTimeout(10000) authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) - input.readInt() match { + val requestMethod = input.readInt() + // The BarrierTaskContext function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + requestMethod match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - // The barrier() function may wait infinitely, socket shall not timeout - // before the function finishes. - sock.setSoTimeout(0) - barrierAndServe(sock) - + barrierAndServe(requestMethod, sock) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val length = input.readInt() + val message = new Array[Byte](length) + input.readFully(message) + barrierAndServe(requestMethod, sock, new String(message, UTF_8)) case _ => val out = new DataOutputStream(new BufferedOutputStream( sock.getOutputStream)) @@ -395,15 +405,31 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } /** - * Gateway to call BarrierTaskContext.barrier(). + * Gateway to call BarrierTaskContext methods. */ - def barrierAndServe(sock: Socket): Unit = { - require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") - + def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { + require( + serverSocket.isDefined, + "No available ServerSocket to redirect the BarrierTaskContext method call." 
+ ) val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - context.asInstanceOf[BarrierTaskContext].barrier() - writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + var result: String = "" + requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].barrier() + result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( + message + ) + result = compact(render(JArray( + messages.map( + (message) => JString(message) + ).toList + ))) + } + writeUTF(result, out) } catch { case e: SparkException => writeUTF(e.getMessage, out) @@ -638,6 +664,7 @@ private[spark] object SpecialLengths { private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_FUNCTION = 1 + val ALL_GATHER_FUNCTION = 2 val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index fc8ac38479932..ed38b7f0ecac1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.File +import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ @@ -52,6 +53,79 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { assert(times.max - times.min <= 1000) } + test("share messages with allGather() call") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. 
+ Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + messages.toList.iterator + } + // Take a sorted list of all the partitionId messages + val messages = rdd2.collect().head + // All the task partitionIds are shared + for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString) + } + + test("throw exception if we attempt to synchronize with different blocking calls") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + val partitionId = context.partitionId + if (partitionId == 0) { + context.barrier() + } else { + context.allGather(partitionId.toString) + } + Seq(null).iterator + } + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("does not match the current synchronized requestMethod")) + } + + test("successively sync with allGather and barrier") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index d648f63338514..90bd2345ac525 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,9 +16,10 @@ # from __future__ import print_function +import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, UTF8Deserializer +from pyspark.serializers import write_int, write_with_length, UTF8Deserializer class TaskContext(object): @@ -107,18 +108,28 @@ def resources(self): BARRIER_FUNCTION = 1 +ALL_GATHER_FUNCTION = 2 -def _load_from_socket(port, auth_secret): +def _load_from_socket(port, auth_secret, function, all_gather_message=None): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) - # The barrier() call may block forever, so no timeout + + # The call may block forever, so no timeout sock.settimeout(None) - # Make a barrier() function call. - write_int(BARRIER_FUNCTION, sockfile) + + if function == BARRIER_FUNCTION: + # Make a barrier() function call. 
+ write_int(function, sockfile) + elif function == ALL_GATHER_FUNCTION: + # Make a all_gather() function call. + write_int(function, sockfile) + write_with_length(all_gather_message.encode("utf-8"), sockfile) + else: + raise ValueError("Unrecognized function type") sockfile.flush() # Collect result. @@ -199,7 +210,33 @@ def barrier(self): raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - _load_from_socket(self._port, self._secret) + _load_from_socket(self._port, self._secret, BARRIER_FUNCTION) + + def allGather(self, message=""): + """ + .. note:: Experimental + + This function blocks until all tasks in the same stage have reached this routine. + Each task passes in a message and returns with a list of all the messages passed in + by each of those tasks. + + .. warning:: In a barrier stage, each task much have the same number of `allGather()` + calls, in all possible code branches. + Otherwise, you may get the job hanging or a SparkException after timeout. + """ + if not isinstance(message, str): + raise ValueError("Argument `message` must be of type `str`") + elif self._port is None or self._secret is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + gathered_items = _load_from_socket( + self._port, + self._secret, + ALL_GATHER_FUNCTION, + message, + ) + return [e for e in json.loads(gathered_items)] def getTaskInfos(self): """ diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 68cfe814762e0..0053aadd9c3ed 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -135,6 +135,26 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) + def test_all_gather(self): + """ + Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks + within a stage and passes messages properly. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + out = tc.allGather(str(context.partitionId())) + pids = [int(e) for e in out] + return [pids] + + pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] + self.assertTrue(pids == [0, 1, 2, 3]) + def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From fa3517cdb163b0589dc02c7d1fefb65be811f65f Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 13 Feb 2020 17:43:55 -0800 Subject: [PATCH 042/185] Revert "[SPARK-30667][CORE] Add allGather method to BarrierTaskContext" This reverts commit 57254c9719f9af9ad985596ed7fbbaafa4052002. 
--- .../org/apache/spark/BarrierCoordinator.scala | 113 ++----------- .../org/apache/spark/BarrierTaskContext.scala | 153 ++++++------------ .../spark/api/python/PythonRunner.scala | 51 ++---- .../scheduler/BarrierTaskContextSuite.scala | 74 --------- python/pyspark/taskcontext.py | 49 +----- python/pyspark/tests/test_taskcontext.py | 20 --- 6 files changed, 79 insertions(+), 381 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 042a2664a0e27..4e417679ca663 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,17 +17,12 @@ package org.apache.spark -import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -104,15 +99,10 @@ private[spark] class BarrierCoordinator( // reset when a barrier() call fails due to timeout. private var barrierEpoch: Int = 0 - // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call + // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() + // call. private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) - // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call - private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] - - // The blocking requestMethod called by tasks to sync up for this stage attempt - private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER - // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -140,32 +130,9 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest( - requester: RpcCallContext, - request: RequestToSync - ): Unit = synchronized { + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch - val requestMethod = request.requestMethod - val partitionId = request.partitionId - val allGatherMessage = request match { - case ag: AllGatherRequestToSync => ag.allGatherMessage - case _ => "" - } - - if (requesters.size == 0) { - requestMethodToSync = requestMethod - } - - if (requestMethodToSync != requestMethod) { - requesters.foreach( - _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + - s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + - s"the current synchronized requestMethod `$requestMethodToSync`" - )) - ) - cleanupBarrierStage(barrierId) - } // Require the number of tasks is correctly set from the BarrierTaskContext. 
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + @@ -186,7 +153,6 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester - allGatherMessages(partitionId) = allGatherMessage logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (maybeFinishAllRequesters(requesters, numTasks)) { @@ -196,7 +162,6 @@ private[spark] class BarrierCoordinator( s"tasks, finished successfully.") barrierEpoch += 1 requesters.clear() - allGatherMessages.clear() cancelTimerTask() } } @@ -208,13 +173,7 @@ private[spark] class BarrierCoordinator( requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { if (requesters.size == numTasks) { - requestMethodToSync match { - case RequestMethod.BARRIER => - requesters.foreach(_.reply("")) - case RequestMethod.ALL_GATHER => - val json: String = compact(render(allGatherMessages)) - requesters.foreach(_.reply(json)) - } + requesters.foreach(_.reply(())) true } else { false @@ -227,7 +186,6 @@ private[spark] class BarrierCoordinator( // messages come from current stage attempt shall fail. barrierEpoch = -1 requesters.clear() - allGatherMessages.clear() cancelTimerTask() } } @@ -241,11 +199,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request: RequestToSync => + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => // Get or init the ContextBarrierState correspond to the stage attempt. - val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) + val barrierId = ContextBarrierId(stageId, stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -258,16 +216,6 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable -private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { - def numTasks: Int - def stageId: Int - def stageAttemptId: Int - def taskAttemptId: Long - def barrierEpoch: Int - def partitionId: Int - def requestMethod: RequestMethod.Value -} - /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is * identified by stageId + stageAttemptId + barrierEpoch. @@ -276,44 +224,11 @@ private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls - * @param partitionId ID of the current partition the task is assigned to - * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. */ -private[spark] case class BarrierRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value -) extends RequestToSync - -/** - * A global sync request message from BarrierTaskContext, by `allGather()` call. 
Each request is - * identified by stageId + stageAttemptId + barrierEpoch. - * - * @param numTasks The number of global sync requests the BarrierCoordinator shall receive - * @param stageId ID of current stage - * @param stageAttemptId ID of current stage attempt - * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls - * @param partitionId ID of the current partition the task is assigned to - * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator - * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER - */ -private[spark] case class AllGatherRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value, - allGatherMessage: String -) extends RequestToSync - -private[spark] object RequestMethod extends Enumeration { - val BARRIER, ALL_GATHER = Value -} +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 2263538a11676..3d369802f3023 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,19 +17,11 @@ package org.apache.spark -import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.json4s.DefaultFormats -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.parse import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -67,31 +59,49 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - private def getRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptNumber: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value, - allGatherMessage: String - ): RequestToSync = { - requestMethod match { - case RequestMethod.BARRIER => - BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch, partitionId, requestMethod) - case RequestMethod.ALL_GATHER => - AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch, partitionId, requestMethod, allGatherMessage) - } - } - - private def runBarrier( - requestMethod: RequestMethod.Value, - allGatherMessage: String = "" - ): String = { - + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. 
Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -108,12 +118,10 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. timer.schedule(timerTask, 60000, 60000) - var json: String = "" - try { - val abortableRpcFuture = barrierCoordinator.askAbortable[String]( - message = getRequestToSync(numTasks, stageId, stageAttemptNumber, - taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), + val abortableRpcFuture = barrierCoordinator.askAbortable[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) @@ -125,7 +133,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -155,73 +163,6 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } - json - } - - /** - * :: Experimental :: - * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to - * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same - * stage have reached this routine. - * - * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all - * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of '''misuses''' are listed below: - * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it - * shall lead to timeout of the function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * if (context.partitionId() == 0) { - * // Do nothing. - * } else { - * context.barrier() - * } - * iter - * } - * }}} - * - * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the - * second function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * try { - * // Do something that might throw an Exception. 
- * doSomething() - * context.barrier() - * } catch { - * case e: Exception => logWarning("...", e) - * } - * context.barrier() - * iter - * } - * }}} - */ - @Experimental - @Since("2.4.0") - def barrier(): Unit = { - runBarrier(RequestMethod.BARRIER) - () - } - - /** - * :: Experimental :: - * Blocks until all tasks in the same stage have reached this routine. Each task passes in - * a message and returns with a list of all the messages passed in by each of those tasks. - * - * CAUTION! The allGather method requires the same precautions as the barrier method - * - * The message is type String rather than Array[Byte] because it is more convenient for - * the user at the cost of worse performance. - */ - @Experimental - @Since("3.0.0") - def allGather(message: String): ArrayBuffer[String] = { - val json = runBarrier(RequestMethod.ALL_GATHER, message) - val jsonArray = parse(json) - implicit val formats = DefaultFormats - ArrayBuffer(jsonArray.extract[Array[String]]: _*) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index fa8bf0fc06358..658e0d593a167 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,13 +24,8 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -243,18 +238,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock.setSoTimeout(10000) authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) - val requestMethod = input.readInt() - // The BarrierTaskContext function may wait infinitely, socket shall not timeout - // before the function finishes. - sock.setSoTimeout(0) - requestMethod match { + input.readInt() match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - barrierAndServe(requestMethod, sock) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val length = input.readInt() - val message = new Array[Byte](length) - input.readFully(message) - barrierAndServe(requestMethod, sock, new String(message, UTF_8)) + // The barrier() function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + barrierAndServe(sock) + case _ => val out = new DataOutputStream(new BufferedOutputStream( sock.getOutputStream)) @@ -405,31 +395,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } /** - * Gateway to call BarrierTaskContext methods. + * Gateway to call BarrierTaskContext.barrier(). */ - def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { - require( - serverSocket.isDefined, - "No available ServerSocket to redirect the BarrierTaskContext method call." 
- ) + def barrierAndServe(sock: Socket): Unit = { + require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - var result: String = "" - requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].barrier() - result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( - message - ) - result = compact(render(JArray( - messages.map( - (message) => JString(message) - ).toList - ))) - } - writeUTF(result, out) + context.asInstanceOf[BarrierTaskContext].barrier() + writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) } catch { case e: SparkException => writeUTF(e.getMessage, out) @@ -664,7 +638,6 @@ private[spark] object SpecialLengths { private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_FUNCTION = 1 - val ALL_GATHER_FUNCTION = 2 val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index ed38b7f0ecac1..fc8ac38479932 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.io.File -import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ @@ -53,79 +52,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { assert(times.max - times.min <= 1000) } - test("share messages with allGather() call") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) - val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { it => - val context = BarrierTaskContext.get() - // Sleep for a random time before global sync. 
- Thread.sleep(Random.nextInt(1000)) - // Pass partitionId message in - val message = context.partitionId().toString - val messages = context.allGather(message) - messages.toList.iterator - } - // Take a sorted list of all the partitionId messages - val messages = rdd2.collect().head - // All the task partitionIds are shared - for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString) - } - - test("throw exception if we attempt to synchronize with different blocking calls") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) - val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { it => - val context = BarrierTaskContext.get() - val partitionId = context.partitionId - if (partitionId == 0) { - context.barrier() - } else { - context.allGather(partitionId.toString) - } - Seq(null).iterator - } - val error = intercept[SparkException] { - rdd2.collect() - }.getMessage - assert(error.contains("does not match the current synchronized requestMethod")) - } - - test("successively sync with allGather and barrier") { - val conf = new SparkConf() - .setMaster("local-cluster[4, 1, 1024]") - .setAppName("test-cluster") - sc = new SparkContext(conf) - val rdd = sc.makeRDD(1 to 10, 4) - val rdd2 = rdd.barrier().mapPartitions { it => - val context = BarrierTaskContext.get() - // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) - context.barrier() - val time1 = System.currentTimeMillis() - // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) - // Pass partitionId message in - val message = context.partitionId().toString - val messages = context.allGather(message) - val time2 = System.currentTimeMillis() - Seq((time1, time2)).iterator - } - val times = rdd2.collect() - // All the tasks shall finish the first round of global sync within a short time slot. - val times1 = times.map(_._1) - assert(times1.max - times1.min <= 1000) - - // All the tasks shall finish the second round of global sync within a short time slot. - val times2 = times.map(_._2) - assert(times2.max - times2.min <= 1000) - } - test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 90bd2345ac525..d648f63338514 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,10 +16,9 @@ # from __future__ import print_function -import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, write_with_length, UTF8Deserializer +from pyspark.serializers import write_int, UTF8Deserializer class TaskContext(object): @@ -108,28 +107,18 @@ def resources(self): BARRIER_FUNCTION = 1 -ALL_GATHER_FUNCTION = 2 -def _load_from_socket(port, auth_secret, function, all_gather_message=None): +def _load_from_socket(port, auth_secret): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) - - # The call may block forever, so no timeout + # The barrier() call may block forever, so no timeout sock.settimeout(None) - - if function == BARRIER_FUNCTION: - # Make a barrier() function call. - write_int(function, sockfile) - elif function == ALL_GATHER_FUNCTION: - # Make a all_gather() function call. 
- write_int(function, sockfile) - write_with_length(all_gather_message.encode("utf-8"), sockfile) - else: - raise ValueError("Unrecognized function type") + # Make a barrier() function call. + write_int(BARRIER_FUNCTION, sockfile) sockfile.flush() # Collect result. @@ -210,33 +199,7 @@ def barrier(self): raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - _load_from_socket(self._port, self._secret, BARRIER_FUNCTION) - - def allGather(self, message=""): - """ - .. note:: Experimental - - This function blocks until all tasks in the same stage have reached this routine. - Each task passes in a message and returns with a list of all the messages passed in - by each of those tasks. - - .. warning:: In a barrier stage, each task much have the same number of `allGather()` - calls, in all possible code branches. - Otherwise, you may get the job hanging or a SparkException after timeout. - """ - if not isinstance(message, str): - raise ValueError("Argument `message` must be of type `str`") - elif self._port is None or self._secret is None: - raise Exception("Not supported to call barrier() before initialize " + - "BarrierTaskContext.") - else: - gathered_items = _load_from_socket( - self._port, - self._secret, - ALL_GATHER_FUNCTION, - message, - ) - return [e for e in json.loads(gathered_items)] + _load_from_socket(self._port, self._secret) def getTaskInfos(self): """ diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 0053aadd9c3ed..68cfe814762e0 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -135,26 +135,6 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) - def test_all_gather(self): - """ - Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks - within a stage and passes messages properly. - """ - rdd = self.sc.parallelize(range(10), 4) - - def f(iterator): - yield sum(iterator) - - def context_barrier(x): - tc = BarrierTaskContext.get() - time.sleep(random.randint(1, 10)) - out = tc.allGather(str(context.partitionId())) - pids = [int(e) for e in out] - return [pids] - - pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] - self.assertTrue(pids == [0, 1, 2, 3]) - def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From 25db8c71a2100c167b8c2d7a6c540ebc61db9b73 Mon Sep 17 00:00:00 2001 From: David Toneian Date: Fri, 14 Feb 2020 11:00:35 +0900 Subject: [PATCH 043/185] [PYSPARK][DOCS][MINOR] Changed `:func:` to `:attr:` Sphinx roles, fixed links in documentation of `Data{Frame,Stream}{Reader,Writer}` This commit is published into the public domain. ### What changes were proposed in this pull request? This PR fixes the documentation of `DataFrameReader`, `DataFrameWriter`, `DataStreamReader`, and `DataStreamWriter`, where attributes of other classes were misrepresented as functions. Additionally, creation of hyperlinks across modules was fixed in these instances. ### Why are the changes needed? The old state produced documentation that suggested invalid usage of PySpark objects (accessing attributes as though they were callable.) ### Does this PR introduce any user-facing change? No, except for improved documentation. ### How was this patch tested? No test added; documentation build runs through. 
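The user-facing point behind swapping `:func:` for `:attr:` is that these entry points are properties, not callables. A short illustrative sketch (the paths are placeholders, and an active `SparkSession` named `spark` is assumed):

```python
# `read` and `write` are attributes, so they are accessed without parentheses.
df = spark.read.csv("/tmp/example.csv", header=True)
df.write.mode("overwrite").parquet("/tmp/example_parquet")

# The usage the old `:func:` role hinted at is invalid -- `read` is not callable:
# spark.read().csv("/tmp/example.csv")  # TypeError: 'DataFrameReader' object is not callable
```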
Closes #27553 from DavidToneian/docfix-DataFrameReader-DataFrameWriter-DataStreamReader-DataStreamWriter. Authored-by: David Toneian Signed-off-by: HyukjinKwon --- python/pyspark/sql/readwriter.py | 4 ++-- python/pyspark/sql/streaming.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 3d3280dbd9943..69660395ad823 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -48,7 +48,7 @@ def _set_opts(self, schema=None, **options): class DataFrameReader(OptionUtils): """ Interface used to load a :class:`DataFrame` from external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`spark.read` + (e.g. file systems, key-value stores, etc). Use :attr:`SparkSession.read` to access this. .. versionadded:: 1.4 @@ -616,7 +616,7 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar class DataFrameWriter(OptionUtils): """ Interface used to write a :class:`DataFrame` to external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write` + (e.g. file systems, key-value stores, etc). Use :attr:`DataFrame.write` to access this. .. versionadded:: 1.4 diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index f17a52f6b3dc8..5fced8aca9bdf 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -276,9 +276,9 @@ def resetTerminated(self): class DataStreamReader(OptionUtils): """ - Interface used to load a streaming :class:`DataFrame` from external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream` - to access this. + Interface used to load a streaming :class:`DataFrame ` from external + storage systems (e.g. file systems, key-value stores, etc). + Use :attr:`SparkSession.readStream ` to access this. .. note:: Evolving. @@ -750,8 +750,9 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non class DataStreamWriter(object): """ - Interface used to write a streaming :class:`DataFrame` to external storage systems - (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream` + Interface used to write a streaming :class:`DataFrame ` to external + storage systems (e.g. file systems, key-value stores, etc). + Use :attr:`DataFrame.writeStream ` to access this. .. note:: Evolving. From 0aed77a0155b404e39bc5dbc0579e29e4c7bf887 Mon Sep 17 00:00:00 2001 From: maryannxue Date: Fri, 14 Feb 2020 11:20:55 +0800 Subject: [PATCH 044/185] [SPARK-30801][SQL] Subqueries should not be AQE-ed if main query is not ### What changes were proposed in this pull request? This PR makes sure AQE is either enabled or disabled for the entire query, including the main query and all subqueries. Currently there are unsupported queries by AQE, e.g., queries that contain DPP filters. We need to make sure that if the main query is unsupported, none of the sub-queries should apply AQE, otherwise it can lead to performance regressions due to missed opportunity of sub-query reuse. ### Why are the changes needed? To get rid of potential perf regressions when AQE is turned on. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Updated DynamicPartitionPruningSuite: 1. Removed the existing workaround `withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")` 2. 
Added `DynamicPartitionPruningSuiteAEOn` and `DynamicPartitionPruningSuiteAEOff` to enable testing this suite with AQE on and off options 3. Added a check in `checkPartitionPruningPredicate` to verify that the subqueries are always in sync with the main query in terms of whether AQE is applied. Closes #27554 from maryannxue/spark-30801. Authored-by: maryannxue Signed-off-by: Wenchen Fan --- .../spark/sql/execution/QueryExecution.scala | 19 ++++++++++--- .../sql/DynamicPartitionPruningSuite.scala | 27 ++++++++++++++++--- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 53b6b5d82c021..9109c05e75853 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -274,13 +274,25 @@ object QueryExecution { * are correct, insert whole stage code gen, and try to reduce the work done by reusing exchanges * and subqueries. */ - private[execution] def preparations(sparkSession: SparkSession): Seq[Rule[SparkPlan]] = + private[execution] def preparations(sparkSession: SparkSession): Seq[Rule[SparkPlan]] = { + + val sparkSessionWithAdaptiveExecutionOff = + if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) { + val session = sparkSession.cloneSession() + session.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false) + session + } else { + sparkSession + } + Seq( // `AdaptiveSparkPlanExec` is a leaf node. If inserted, all the following rules will be no-op // as the original plan is hidden behind `AdaptiveSparkPlanExec`. InsertAdaptiveSparkPlan(AdaptiveExecutionContext(sparkSession)), - PlanDynamicPruningFilters(sparkSession), - PlanSubqueries(sparkSession), + // If the following rules apply, it means the main query is not AQE-ed, so we make sure the + // subqueries are not AQE-ed either. 
+ PlanDynamicPruningFilters(sparkSessionWithAdaptiveExecutionOff), + PlanSubqueries(sparkSessionWithAdaptiveExecutionOff), EnsureRequirements(sparkSession.sessionState.conf), ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf, sparkSession.sessionState.columnarRules), @@ -288,6 +300,7 @@ object QueryExecution { ReuseExchange(sparkSession.sessionState.conf), ReuseSubquery(sparkSession.sessionState.conf) ) + } /** * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala index f7b51d6f4c8ef..baa9f5ecafc68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.GivenWhenThen import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression} import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec} import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} @@ -33,7 +33,7 @@ import org.apache.spark.sql.test.SharedSparkSession /** * Test suite for the filtering ratio policy used to trigger dynamic partition pruning (DPP). */ -class DynamicPartitionPruningSuite +abstract class DynamicPartitionPruningSuiteBase extends QueryTest with SharedSparkSession with GivenWhenThen @@ -43,9 +43,14 @@ class DynamicPartitionPruningSuite import testImplicits._ + val adaptiveExecutionOn: Boolean + override def beforeAll(): Unit = { super.beforeAll() + spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, adaptiveExecutionOn) + spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY, true) + val factData = Seq[(Int, Int, Int, Int)]( (1000, 1, 1, 10), (1010, 2, 1, 10), @@ -153,6 +158,8 @@ class DynamicPartitionPruningSuite sql("DROP TABLE IF EXISTS fact_stats") sql("DROP TABLE IF EXISTS dim_stats") } finally { + spark.sessionState.conf.unsetConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED) + spark.sessionState.conf.unsetConf(SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY) super.afterAll() } } @@ -195,6 +202,11 @@ class DynamicPartitionPruningSuite fail(s"Invalid child node found in\n$s") } } + + val isMainQueryAdaptive = plan.isInstanceOf[AdaptiveSparkPlanExec] + subqueriesAll(plan).filterNot(subqueryBroadcast.contains).foreach { s => + assert(s.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined == isMainQueryAdaptive) + } } /** @@ -1173,8 +1185,7 @@ class DynamicPartitionPruningSuite } test("join key with multiple references on the filtering plan") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", - SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { // when enable AQE, the reusedExchange is inserted when executed. 
withTable("fact", "dim") { spark.range(100).select( @@ -1270,3 +1281,11 @@ class DynamicPartitionPruningSuite } } } + +class DynamicPartitionPruningSuiteAEOff extends DynamicPartitionPruningSuiteBase { + override val adaptiveExecutionOn: Boolean = false +} + +class DynamicPartitionPruningSuiteAEOn extends DynamicPartitionPruningSuiteBase { + override val adaptiveExecutionOn: Boolean = true +} From b2134ee73cfad4d78aaf8f0a2898011ac0308e48 Mon Sep 17 00:00:00 2001 From: David Toneian Date: Fri, 14 Feb 2020 13:49:11 +0900 Subject: [PATCH 045/185] [SPARK-30823][PYTHON][DOCS] Set `%PYTHONPATH%` when building PySpark documentation on Windows MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit is published into the public domain. ### What changes were proposed in this pull request? In analogy to `python/docs/Makefile`, which has > export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.8.1-src.zip) on line 10, this PR adds > set PYTHONPATH=..;..\lib\py4j-0.10.8.1-src.zip to `make2.bat`. Since there is no `realpath` in default installations of Windows, I left the relative paths unresolved. Per the instructions on how to build docs, `make.bat` is supposed to be run from `python/docs` as the working directory, so this should probably not cause issues (`%BUILDDIR%` is a relative path as well.) ### Why are the changes needed? When building the PySpark documentation on Windows, by changing directory to `python/docs` and running `make.bat` (which runs `make2.bat`), the majority of the documentation may not be built if pyspark is not in the default `%PYTHONPATH%`. Sphinx then reports that `pyspark` (and possibly dependencies) cannot be imported. If `pyspark` is in the default `%PYTHONPATH%`, I suppose it is that version of `pyspark` – as opposed to the version found above the `python/docs` directory – that is considered when building the documentation, which may result in documentation that does not correspond to the development version one is trying to build. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Manual tests on my Windows 10 machine. Additional tests with other environments very welcome! Closes #27569 from DavidToneian/SPARK-30823. Authored-by: David Toneian Signed-off-by: HyukjinKwon --- python/docs/make2.bat | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/docs/make2.bat b/python/docs/make2.bat index 05d22eb5cdd23..742df373166da 100644 --- a/python/docs/make2.bat +++ b/python/docs/make2.bat @@ -2,6 +2,8 @@ REM Command file for Sphinx documentation +set PYTHONPATH=..;..\lib\py4j-0.10.8.1-src.zip + if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build From 99b8136a86030411e6bcbd312f40eb2a901ab0f0 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Fri, 14 Feb 2020 16:52:28 +0800 Subject: [PATCH 046/185] [SPARK-25990][SQL] ScriptTransformation should handle different data types correctly ### What changes were proposed in this pull request? We should convert Spark InternalRows to hive data via `HiveInspectors.wrapperFor`. ### Why are the changes needed? 
We may hit below exception without this change: ``` [info] org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 1, 192.168.1.6, executor driver): java.lang.ClassCastException: org.apache.spark.sql.types.Decimal cannot be cast to org.apache.hadoop.hive.common.type.HiveDecimal [info] at org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector.getPrimitiveJavaObject(JavaHiveDecimalObjectInspector.java:55) [info] at org.apache.hadoop.hive.serde2.lazy.LazyUtils.writePrimitiveUTF8(LazyUtils.java:321) [info] at org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.serialize(LazySimpleSerDe.java:292) [info] at org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.serializeField(LazySimpleSerDe.java:247) [info] at org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe.doSerialize(LazySimpleSerDe.java:231) [info] at org.apache.hadoop.hive.serde2.AbstractEncodingAwareSerDe.serialize(AbstractEncodingAwareSerDe.java:55) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.$anonfun$run$2(ScriptTransformationExec.scala:300) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.$anonfun$run$2$adapted(ScriptTransformationExec.scala:281) [info] at scala.collection.Iterator.foreach(Iterator.scala:941) [info] at scala.collection.Iterator.foreach$(Iterator.scala:941) [info] at scala.collection.AbstractIterator.foreach(Iterator.scala:1429) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.$anonfun$run$1(ScriptTransformationExec.scala:281) [info] at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23) [info] at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1932) [info] at org.apache.spark.sql.hive.execution.ScriptTransformationWriterThread.run(ScriptTransformationExec.scala:270) ``` ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Added new test. But please note that this test returns different result between Hive1.2 and Hive2.3 due to `HiveDecimal` or `SerDe` difference(don't know the root cause yet). Closes #27556 from Ngone51/script_transform. Lead-authored-by: yi.wu Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../execution/ScriptTransformationExec.scala | 32 ++++++++----- sql/hive/src/test/resources/test_script.py | 21 +++++++++ .../execution/ScriptTransformationSuite.scala | 46 ++++++++++++++++++- 3 files changed, 85 insertions(+), 14 deletions(-) create mode 100644 sql/hive/src/test/resources/test_script.py diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala index e12f663304e7a..40f7b4e8db7c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala @@ -94,9 +94,8 @@ case class ScriptTransformationExec( // This new thread will consume the ScriptTransformation's input rows and write them to the // external process. That process's output will be read by this current thread. 
val writerThread = new ScriptTransformationWriterThread( - inputIterator, + inputIterator.map(outputProjection), input.map(_.dataType), - outputProjection, inputSerde, inputSoi, ioschema, @@ -249,16 +248,15 @@ case class ScriptTransformationExec( private class ScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], - outputProjection: Projection, @Nullable inputSerde: AbstractSerDe, - @Nullable inputSoi: ObjectInspector, + @Nullable inputSoi: StructObjectInspector, ioschema: HiveScriptIOSchema, outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration - ) extends Thread("Thread-ScriptTransformation-Feed") with Logging { + ) extends Thread("Thread-ScriptTransformation-Feed") with HiveInspectors with Logging { setDaemon(true) @@ -278,8 +276,8 @@ private class ScriptTransformationWriterThread( var threwException: Boolean = true val len = inputSchema.length try { - iter.map(outputProjection).foreach { row => - if (inputSerde == null) { + if (inputSerde == null) { + iter.foreach { row => val data = if (len == 0) { ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { @@ -295,10 +293,21 @@ private class ScriptTransformationWriterThread( sb.toString() } outputStream.write(data.getBytes(StandardCharsets.UTF_8)) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) + } + } else { + // Convert Spark InternalRows to hive data via `HiveInspectors.wrapperFor`. + val hiveData = new Array[Any](inputSchema.length) + val fieldOIs = inputSoi.getAllStructFieldRefs.asScala.map(_.getFieldObjectInspector).toArray + val wrappers = fieldOIs.zip(inputSchema).map { case (f, dt) => wrapperFor(f, dt) } + + iter.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + hiveData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, inputSchema(i))) + i += 1 + } + val writable = inputSerde.serialize(hiveData, inputSoi) if (scriptInputWriter != null) { scriptInputWriter.write(writable) } else { @@ -374,14 +383,13 @@ case class HiveScriptIOSchema ( val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = { + def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { inputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(input) val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) - .asInstanceOf[ObjectInspector] (serde, objectInspector) } } diff --git a/sql/hive/src/test/resources/test_script.py b/sql/hive/src/test/resources/test_script.py new file mode 100644 index 0000000000000..82ef7b38f0c1b --- /dev/null +++ b/sql/hive/src/test/resources/test_script.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +for line in sys.stdin: + (a, b, c, d, e) = line.split('\t') + sys.stdout.write('\t'.join([a, b, c, d, e])) + sys.stdout.flush() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 80a50c18bcb93..7d01fc53a4099 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.sql.Timestamp + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.scalatest.Assertions._ import org.scalatest.BeforeAndAfterEach @@ -24,15 +26,18 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.{SparkException, TaskContext, TestUtils} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StringType -class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton with - BeforeAndAfterEach { +class ScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton + with BeforeAndAfterEach { import spark.implicits._ private val noSerdeIOSchema = HiveScriptIOSchema( @@ -186,6 +191,43 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton wit rowsDf.select("name").collect()) assert(uncaughtExceptionHandler.exception.isEmpty) } + + test("SPARK-25990: TRANSFORM should handle different data types correctly") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + val query = sql( + s""" + |SELECT + |TRANSFORM(a, b, c, d, e) + |USING 'python $scriptFilePath' AS (a, b, c, d, e) + |FROM v + """.stripMargin) + + // In Hive1.2, it does not do well on Decimal conversion. For example, in this case, + // it converts a decimal value's type from Decimal(38, 18) to Decimal(1, 0). So we need + // do extra cast here for Hive1.2. But in Hive2.3, it still keeps the original Decimal type. 
+ val decimalToString: Column => Column = if (HiveUtils.isHive23) { + c => c.cast("string") + } else { + c => c.cast("decimal(1, 0)").cast("string") + } + checkAnswer(query, identity, df.select( + 'a.cast("string"), + 'b.cast("string"), + 'c.cast("string"), + decimalToString('d), + 'e.cast("string")).collect()) + } + } } private case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { From 2a270a731a3b1da9a0fb036d648dd522e5c4d5ad Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 14 Feb 2020 18:20:18 +0800 Subject: [PATCH 047/185] [SPARK-30810][SQL] Parses and convert a CSV Dataset having different column from 'value' in csv(dataset) API ### What changes were proposed in this pull request? This PR fixes `DataFrameReader.csv(dataset: Dataset[String])` API to take a `Dataset[String]` originated from a column name different from `value`. This is a long-standing bug started from the very first place. `CSVUtils.filterCommentAndEmpty` assumed the `Dataset[String]` to be originated with `value` column. This PR changes to use the first column name in the schema. ### Why are the changes needed? For `DataFrameReader.csv(dataset: Dataset[String])` to support any `Dataset[String]` as the signature indicates. ### Does this PR introduce any user-facing change? Yes, ```scala val ds = spark.range(2).selectExpr("concat('a,b,', id) AS text").as[String] spark.read.option("header", true).option("inferSchema", true).csv(ds).show() ``` Before: ``` org.apache.spark.sql.AnalysisException: cannot resolve '`value`' given input columns: [text];; 'Filter (length(trim('value, None)) > 0) +- Project [concat(a,b,, cast(id#0L as string)) AS text#2] +- Range (0, 2, step=1, splits=Some(2)) ``` After: ``` +---+---+---+ | a| b| 0| +---+---+---+ | a| b| 1| +---+---+---+ ``` ### How was this patch tested? Unittest was added. Closes #27561 from HyukjinKwon/SPARK-30810. Authored-by: HyukjinKwon Signed-off-by: Wenchen Fan --- .../spark/sql/execution/datasources/csv/CSVUtils.scala | 7 ++++--- .../spark/sql/execution/datasources/csv/CSVSuite.scala | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 21fabac472f4b..d8b52c503ad34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -33,11 +33,12 @@ object CSVUtils { // with the one below, `filterCommentAndEmpty` but execution path is different. One of them // might have to be removed in the near future if possible. 
import lines.sqlContext.implicits._ - val nonEmptyLines = lines.filter(length(trim($"value")) > 0) + val aliased = lines.toDF("value") + val nonEmptyLines = aliased.filter(length(trim($"value")) > 0) if (options.isCommentSet) { - nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)) + nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).as[String] } else { - nonEmptyLines + nonEmptyLines.as[String] } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index b1105b4a63bba..0be0e1e3da3dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -2294,6 +2294,13 @@ abstract class CSVSuite extends QueryTest with SharedSparkSession with TestCsvDa } } } + + test("SPARK-30810: parses and convert a CSV Dataset having different column from 'value'") { + val ds = spark.range(2).selectExpr("concat('a,b,', id) AS `a.text`").as[String] + val csv = spark.read.option("header", true).option("inferSchema", true).csv(ds) + assert(csv.schema.fieldNames === Seq("a", "b", "0")) + checkAnswer(csv, Row("a", "b", 1)) + } } class CSVv1Suite extends CSVSuite { From 7137a6d065edeaab97bf5bf49ffaca3d060a14fe Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 14 Feb 2020 22:16:57 +0800 Subject: [PATCH 048/185] [SPARK-30766][SQL] Fix the timestamp truncation to the `HOUR` and `DAY` levels ### What changes were proposed in this pull request? In the PR, I propose to use Java 8 time API in timestamp truncations to the levels of `HOUR` and `DAY`. The problem is in the usage of `timeZone.getOffset(millis)` in days/hours truncations where the combined calendar (Julian + Gregorian) is used underneath. ### Why are the changes needed? The change fix wrong truncations. For example, the following truncation to hours should print `0010-01-01 01:00:00` but it outputs wrong timestamp: ```scala Seq("0010-01-01 01:02:03.123456").toDF() .select($"value".cast("timestamp").as("ts")) .select(date_trunc("HOUR", $"ts").cast("string")) .show(false) +------------------------------------+ |CAST(date_trunc(HOUR, ts) AS STRING)| +------------------------------------+ |0010-01-01 01:30:17 | +------------------------------------+ ``` ### Does this PR introduce any user-facing change? Yes. After the changes, the result of the example above is: ```scala +------------------------------------+ |CAST(date_trunc(HOUR, ts) AS STRING)| +------------------------------------+ |0010-01-01 01:00:00 | +------------------------------------+ ``` ### How was this patch tested? - Added new test to `DateFunctionsSuite` - By `DateExpressionsSuite` and `DateTimeUtilsSuite` Closes #27512 from MaxGekk/fix-trunc-old-timestamp. 
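For reference, the same reproduction expressed through the PySpark API (a sketch assuming an active `SparkSession` named `spark`; the exact pre-fix offset depends on the session time zone, since it comes from the hybrid Julian + Gregorian calendar):

```python
from pyspark.sql import functions as F

df = spark.createDataFrame([("0010-01-01 01:02:03.123456",)], ["value"])
ts = F.col("value").cast("timestamp")
df.select(
    F.date_trunc("HOUR", ts).cast("string").alias("hour_trunc"),
    F.date_trunc("DAY", ts).cast("string").alias("day_trunc"),
).show(truncate=False)
# Expected with the fix: 0010-01-01 01:00:00 and 0010-01-01 00:00:00
```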
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../expressions/datetimeExpressions.scala | 6 +-- .../sql/catalyst/util/DateTimeUtils.scala | 44 ++++++++++--------- .../catalyst/util/DateTimeUtilsSuite.scala | 39 ++++++++-------- .../apache/spark/sql/DateFunctionsSuite.scala | 13 ++++++ 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index cf91489d8e6b7..adf7251256041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1690,15 +1690,15 @@ case class TruncTimestamp( override def eval(input: InternalRow): Any = { evalHelper(input, minLevel = MIN_LEVEL_OF_TIMESTAMP_TRUNC) { (t: Any, level: Int) => - DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone) + DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, zoneId) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) codeGenHelper(ctx, ev, minLevel = MIN_LEVEL_OF_TIMESTAMP_TRUNC, true) { (date: String, fmt: String) => - s"truncTimestamp($date, $fmt, $tz);" + s"truncTimestamp($date, $fmt, $zid);" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 5976bcbb52fd7..dcc7337116777 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -713,32 +713,34 @@ object DateTimeUtils { } } + private def truncToUnit(t: SQLTimestamp, zoneId: ZoneId, unit: ChronoUnit): SQLTimestamp = { + val truncated = microsToInstant(t).atZone(zoneId).truncatedTo(unit) + instantToMicros(truncated.toInstant) + } + /** * Returns the trunc date time from original date time and trunc level. * Trunc level should be generated using `parseTruncLevel()`, should be between 0 and 12. 
*/ - def truncTimestamp(t: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = { - if (level == TRUNC_TO_MICROSECOND) return t - var millis = MICROSECONDS.toMillis(t) - val truncated = level match { - case TRUNC_TO_MILLISECOND => millis - case TRUNC_TO_SECOND => - millis - millis % MILLIS_PER_SECOND - case TRUNC_TO_MINUTE => - millis - millis % MILLIS_PER_MINUTE - case TRUNC_TO_HOUR => - val offset = timeZone.getOffset(millis) - millis += offset - millis - millis % MILLIS_PER_HOUR - offset - case TRUNC_TO_DAY => - val offset = timeZone.getOffset(millis) - millis += offset - millis - millis % MILLIS_PER_DAY - offset - case _ => // Try to truncate date levels - val dDays = millisToDays(millis, timeZone.toZoneId) - daysToMillis(truncDate(dDays, level), timeZone.toZoneId) + def truncTimestamp(t: SQLTimestamp, level: Int, zoneId: ZoneId): SQLTimestamp = { + level match { + case TRUNC_TO_MICROSECOND => t + case TRUNC_TO_HOUR => truncToUnit(t, zoneId, ChronoUnit.HOURS) + case TRUNC_TO_DAY => truncToUnit(t, zoneId, ChronoUnit.DAYS) + case _ => + val millis = MICROSECONDS.toMillis(t) + val truncated = level match { + case TRUNC_TO_MILLISECOND => millis + case TRUNC_TO_SECOND => + millis - millis % MILLIS_PER_SECOND + case TRUNC_TO_MINUTE => + millis - millis % MILLIS_PER_MINUTE + case _ => // Try to truncate date levels + val dDays = millisToDays(millis, zoneId) + daysToMillis(truncDate(dDays, level), zoneId) + } + truncated * MICROS_PER_MILLIS } - truncated * MICROS_PER_MILLIS } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index cd0594c775a47..ff4d8a2457922 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -499,9 +499,9 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { level: Int, expected: String, inputTS: SQLTimestamp, - timezone: TimeZone = DateTimeUtils.defaultTimeZone()): Unit = { + zoneId: ZoneId = defaultZoneId): Unit = { val truncated = - DateTimeUtils.truncTimestamp(inputTS, level, timezone) + DateTimeUtils.truncTimestamp(inputTS, level, zoneId) val expectedTS = toTimestamp(expected, defaultZoneId) assert(truncated === expectedTS.get) } @@ -539,6 +539,7 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { for (tz <- ALL_TIMEZONES) { withDefaultTimeZone(tz) { + val zid = tz.toZoneId val inputTS = DateTimeUtils.stringToTimestamp( UTF8String.fromString("2015-03-05T09:32:05.359"), defaultZoneId) val inputTS1 = DateTimeUtils.stringToTimestamp( @@ -552,23 +553,23 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { val inputTS5 = DateTimeUtils.stringToTimestamp( UTF8String.fromString("1999-03-29T01:02:03.456789"), defaultZoneId) - testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, tz) - 
testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_DECADE, "1990-01-01", inputTS5.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_CENTURY, "1901-01-01", inputTS5.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MILLENNIUM, "2001-01-01", inputTS.get, tz) + testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_DECADE, "1990-01-01", inputTS5.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_CENTURY, "1901-01-01", inputTS5.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MILLENNIUM, "2001-01-01", inputTS.get, zid) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 41d53c959ef99..ba45b9f9b62df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -856,4 +856,17 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { TimeZone.setDefault(defaultTz) } } + + test("SPARK-30766: date_trunc of old timestamps to hours and days") { + def checkTrunc(level: String, expected: String): Unit = { + val df = Seq("0010-01-01 01:02:03.123456") + .toDF() + .select($"value".cast("timestamp").as("ts")) + .select(date_trunc(level, $"ts").cast("string")) + checkAnswer(df, Row(expected)) + } + + checkTrunc("HOUR", "0010-01-01 01:00:00") + checkTrunc("DAY", "0010-01-01 00:00:00") + } } From b343757b1bd5d0344b82f36aa4d65ed34f840606 Mon Sep 17 00:00:00 2001 From: HyukjinKwon Date: Fri, 14 Feb 2020 10:18:08 -0800 Subject: [PATCH 049/185] [SPARK-29748][DOCS][FOLLOW-UP] Add a note that the legacy environment variable to set in both executor and driver ### What changes were proposed in this pull request? 
This PR address the comment at https://github.com/apache/spark/pull/26496#discussion_r379194091 and improves the migration guide to explicitly note that the legacy environment variable to set in both executor and driver. ### Why are the changes needed? To clarify this env should be set both in driver and executors. ### Does this PR introduce any user-facing change? Nope. ### How was this patch tested? I checked it via md editor. Closes #27573 from HyukjinKwon/SPARK-29748. Authored-by: HyukjinKwon Signed-off-by: Shixiong Zhu --- docs/pyspark-migration-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/pyspark-migration-guide.md b/docs/pyspark-migration-guide.md index 8ea4fec75edf8..f7f20389aa694 100644 --- a/docs/pyspark-migration-guide.md +++ b/docs/pyspark-migration-guide.md @@ -87,7 +87,7 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. - Since Spark 3.0, `Column.getItem` is fixed such that it does not call `Column.apply`. Consequently, if `Column` is used as an argument to `getItem`, the indexing operator should be used. For example, `map_col.getItem(col('id'))` should be replaced with `map_col[col('id')]`. - - As of Spark 3.0 `Row` field names are no longer sorted alphabetically when constructing with named arguments for Python versions 3.6 and above, and the order of fields will match that as entered. To enable sorted fields by default, as in Spark 2.4, set the environment variable `PYSPARK_ROW_FIELD_SORTING_ENABLED` to "true". For Python versions less than 3.6, the field names will be sorted alphabetically as the only option. + - As of Spark 3.0 `Row` field names are no longer sorted alphabetically when constructing with named arguments for Python versions 3.6 and above, and the order of fields will match that as entered. To enable sorted fields by default, as in Spark 2.4, set the environment variable `PYSPARK_ROW_FIELD_SORTING_ENABLED` to "true" for both executors and driver - this environment variable must be consistent on all executors and driver; otherwise, it may cause failures or incorrect answers. For Python versions less than 3.6, the field names will be sorted alphabetically as the only option. ## Upgrading from PySpark 2.3 to 2.4 From d273a2bb0fac452a97f5670edd69d3e452e3e57e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 14 Feb 2020 12:36:52 -0800 Subject: [PATCH 050/185] [SPARK-20628][CORE][K8S] Start to improve Spark decommissioning & preemption support This PR is based on an existing/previou PR - https://github.com/apache/spark/pull/19045 ### What changes were proposed in this pull request? This changes adds a decommissioning state that we can enter when the cloud provider/scheduler lets us know we aren't going to be removed immediately but instead will be removed soon. This concept fits nicely in K8s and also with spot-instances on AWS / preemptible instances all of which we can get a notice that our host is going away. For now we simply stop scheduling jobs, in the future we could perform some kind of migration of data during scale-down, or at least stop accepting new blocks to cache. There is a design document at https://docs.google.com/document/d/1xVO1b6KAwdUhjEJBolVPl9C6sLj7oOveErwDSYdT-pE/edit?usp=sharing ### Why are the changes needed? With more move to preemptible multi-tenancy, serverless environments, and spot-instances better handling of node scale down is required. ### Does this PR introduce any user-facing change? 
There is no API change, however an additional configuration flag is added to enable/disable this behaviour. ### How was this patch tested? New integration tests in the Spark K8s integration testing. Extension of the AppClientSuite to test decommissioning seperate from the K8s. Closes #26440 from holdenk/SPARK-20628-keep-track-of-nodes-which-are-going-to-be-shutdown-r4. Lead-authored-by: Holden Karau Co-authored-by: Holden Karau Signed-off-by: Holden Karau --- .../apache/spark/deploy/DeployMessage.scala | 11 ++ .../apache/spark/deploy/ExecutorState.scala | 8 +- .../deploy/client/StandaloneAppClient.scala | 2 + .../client/StandaloneAppClientListener.scala | 2 + .../apache/spark/deploy/master/Master.scala | 31 +++++ .../apache/spark/deploy/worker/Worker.scala | 26 ++++ .../CoarseGrainedExecutorBackend.scala | 39 +++++- .../org/apache/spark/executor/Executor.scala | 16 +++ .../apache/spark/internal/config/Worker.scala | 5 + .../main/scala/org/apache/spark/rdd/RDD.scala | 2 + .../spark/scheduler/ExecutorLossReason.scala | 8 ++ .../org/apache/spark/scheduler/Pool.scala | 4 + .../apache/spark/scheduler/Schedulable.scala | 1 + .../spark/scheduler/SchedulerBackend.scala | 3 + .../spark/scheduler/TaskScheduler.scala | 5 + .../spark/scheduler/TaskSchedulerImpl.scala | 5 + .../spark/scheduler/TaskSetManager.scala | 6 + .../cluster/CoarseGrainedClusterMessage.scala | 2 + .../CoarseGrainedSchedulerBackend.scala | 66 +++++++++- .../cluster/StandaloneSchedulerBackend.scala | 6 + .../org/apache/spark/util/SignalUtils.scala | 2 +- .../spark/deploy/client/AppClientSuite.scala | 39 +++++- .../spark/scheduler/DAGSchedulerSuite.scala | 2 + .../ExternalClusterManagerSuite.scala | 1 + .../scheduler/WorkerDecommissionSuite.scala | 84 +++++++++++++ .../spark/deploy/k8s/KubernetesConf.scala | 3 + .../features/BasicExecutorFeatureStep.scala | 20 ++- .../src/main/dockerfiles/spark/Dockerfile | 4 +- .../src/main/dockerfiles/spark/decom.sh | 35 ++++++ .../src/main/dockerfiles/spark/entrypoint.sh | 6 +- .../dev/dev-run-integration-tests.sh | 9 +- .../integrationtest/DecommissionSuite.scala | 49 ++++++++ .../k8s/integrationtest/KubernetesSuite.scala | 117 ++++++++++++++---- .../tests/decommissioning.py | 45 +++++++ sbin/decommission-slave.sh | 57 +++++++++ sbin/spark-daemon.sh | 15 +++ 36 files changed, 690 insertions(+), 46 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala create mode 100755 resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh create mode 100644 resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala create mode 100644 resource-managers/kubernetes/integration-tests/tests/decommissioning.py create mode 100644 sbin/decommission-slave.sh diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index fba371dcfb761..18305ad3746a6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -60,6 +60,15 @@ private[deploy] object DeployMessages { assert (port > 0) } + /** + * @param id the worker id + * @param worker the worker endpoint ref + */ + case class WorkerDecommission( + id: String, + worker: RpcEndpointRef) + extends DeployMessage + case class ExecutorStateChanged( appId: String, execId: Int, @@ -149,6 +158,8 @@ private[deploy] object DeployMessages { case object 
ReregisterWithMaster // used when a worker attempts to reconnect to a master + case object DecommissionSelf // Mark as decommissioned. May be Master to Worker in the future. + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index 69c98e28931d7..0751bcf221f86 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -19,9 +19,13 @@ package org.apache.spark.deploy private[deploy] object ExecutorState extends Enumeration { - val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED = Value + val LAUNCHING, RUNNING, KILLED, FAILED, LOST, EXITED, DECOMMISSIONED = Value type ExecutorState = Value - def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST, EXITED).contains(state) + // DECOMMISSIONED isn't listed as finished since we don't want to remove the executor from + // the worker and the executor still exists - but we do want to avoid scheduling new tasks on it. + private val finishedStates = Seq(KILLED, FAILED, LOST, EXITED) + + def isFinished(state: ExecutorState): Boolean = finishedStates.contains(state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 8f17159228f8b..eedf5e969e291 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -180,6 +180,8 @@ private[spark] class StandaloneAppClient( logInfo("Executor updated: %s is now %s%s".format(fullId, state, messageText)) if (ExecutorState.isFinished(state)) { listener.executorRemoved(fullId, message.getOrElse(""), exitStatus, workerLost) + } else if (state == ExecutorState.DECOMMISSIONED) { + listener.executorDecommissioned(fullId, message.getOrElse("")) } case WorkerRemoved(id, host, message) => diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala index d8bc1a883def1..2e38a6847891d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClientListener.scala @@ -39,5 +39,7 @@ private[spark] trait StandaloneAppClientListener { def executorRemoved( fullId: String, message: String, exitStatus: Option[Int], workerLost: Boolean): Unit + def executorDecommissioned(fullId: String, message: String): Unit + def workerRemoved(workerId: String, host: String, message: String): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8d3795cae707a..71df5dfa423a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -243,6 +243,15 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) + case WorkerDecommission(id, workerRef) => + logInfo("Recording worker %s decommissioning".format(id)) + if (state == RecoveryState.STANDBY) { + workerRef.send(MasterInStandby) + } else { + // We use foreach since get gives us 
an option and we can skip the failures. + idToWorker.get(id).foreach(decommissionWorker) + } + case RegisterWorker( id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress, resources) => @@ -313,7 +322,9 @@ private[deploy] class Master( // Only retry certain number of times so we don't go into an infinite loop. // Important note: this code path is not exercised by tests, so be very careful when // changing this `if` condition. + // We also don't count failures from decommissioned workers since they are "expected." if (!normalExit + && oldState != ExecutorState.DECOMMISSIONED && appInfo.incrementRetryCount() >= maxExecutorRetries && maxExecutorRetries >= 0) { // < 0 disables this application-killing path val execs = appInfo.executors.values @@ -850,6 +861,26 @@ private[deploy] class Master( true } + private def decommissionWorker(worker: WorkerInfo): Unit = { + if (worker.state != WorkerState.DECOMMISSIONED) { + logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port)) + worker.setState(WorkerState.DECOMMISSIONED) + for (exec <- worker.executors.values) { + logInfo("Telling app of decommission executors") + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.DECOMMISSIONED, + Some("worker decommissioned"), None, workerLost = false)) + exec.state = ExecutorState.DECOMMISSIONED + exec.application.removeExecutor(exec) + } + // On recovery do not add a decommissioned executor + persistenceEngine.removeWorker(worker) + } else { + logWarning("Skipping decommissioning worker %s on %s:%d as worker is already decommissioned". + format(worker.id, worker.host, worker.port)) + } + } + private def removeWorker(worker: WorkerInfo, msg: String): Unit = { logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 4be495ac4f13f..d988bcedb47f0 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -67,6 +67,14 @@ private[deploy] class Worker( Utils.checkHost(host) assert (port > 0) + // If worker decommissioning is enabled register a handler on PWR to shutdown. + if (conf.get(WORKER_DECOMMISSION_ENABLED)) { + logInfo("Registering SIGPWR handler to trigger decommissioning.") + SignalUtils.register("PWR")(decommissionSelf) + } else { + logInfo("Worker decommissioning not enabled, SIGPWR will result in exiting.") + } + // A scheduled executor used to send messages at the specified time. private val forwardMessageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") @@ -128,6 +136,7 @@ private[deploy] class Worker( private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString private var registered = false private var connected = false + private var decommissioned = false private val workerId = generateWorkerId() private val sparkHome = if (sys.props.contains(IS_TESTING.key)) { @@ -549,6 +558,8 @@ private[deploy] class Worker( case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, resources_) => if (masterUrl != activeMasterUrl) { logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") + } else if (decommissioned) { + logWarning("Asked to launch an executor while decommissioned. 
Not launching executor.") } else { try { logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name)) @@ -672,6 +683,9 @@ private[deploy] class Worker( case ApplicationFinished(id) => finishedApps += id maybeCleanupApplication(id) + + case DecommissionSelf => + decommissionSelf() } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -771,6 +785,18 @@ private[deploy] class Worker( } } + private[deploy] def decommissionSelf(): Boolean = { + if (conf.get(WORKER_DECOMMISSION_ENABLED)) { + logDebug("Decommissioning self") + decommissioned = true + sendToMaster(WorkerDecommission(workerId, self)) + } else { + logWarning("Asked to decommission self, but decommissioning not enabled") + } + // Return true since can be called as a signal handler + true + } + private[worker] def handleDriverStateChanged(driverStateChanged: DriverStateChanged): Unit = { val driverId = driverStateChanged.driverId val exception = driverStateChanged.exception diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 25c5b9812fa1a..faf03a64ae8b2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -43,7 +43,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, ThreadUtils, Utils} +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, SignalUtils, ThreadUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -64,6 +64,7 @@ private[spark] class CoarseGrainedExecutorBackend( private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null + @volatile private var decommissioned = false @volatile var driver: Option[RpcEndpointRef] = None // If this CoarseGrainedExecutorBackend is changed to support multiple threads, then this may need @@ -80,6 +81,9 @@ private[spark] class CoarseGrainedExecutorBackend( private[executor] val taskResources = new mutable.HashMap[Long, Map[String, ResourceInformation]] override def onStart(): Unit = { + logInfo("Registering PWR handler.") + SignalUtils.register("PWR")(decommissionSelf) + logInfo("Connecting to driver: " + driverUrl) try { _resources = parseOrFindResources(resourcesFileOpt) @@ -160,6 +164,16 @@ private[spark] class CoarseGrainedExecutorBackend( if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") } else { + if (decommissioned) { + logError("Asked to launch a task while decommissioned.") + driver match { + case Some(endpoint) => + logInfo("Sending DecommissionExecutor to driver.") + endpoint.send(DecommissionExecutor(executorId)) + case _ => + logError("No registered driver to send Decommission to.") + } + } val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) taskResources(taskDesc.taskId) = taskDesc.resources @@ -242,6 +256,29 @@ private[spark] class CoarseGrainedExecutorBackend( System.exit(code) } + + private def decommissionSelf(): Boolean = { + logInfo("Decommissioning self w/sync") + try { + decommissioned = true + // Tell master we are are 
decommissioned so it stops trying to schedule us + if (driver.nonEmpty) { + driver.get.askSync[Boolean](DecommissionExecutor(executorId)) + } else { + logError("No driver to message decommissioning.") + } + if (executor != null) { + executor.decommission() + } + logInfo("Done decommissioning self.") + // Return true since we are handling a signal + true + } catch { + case e: Exception => + logError(s"Error ${e} during attempt to decommission self") + false + } + } } private[spark] object CoarseGrainedExecutorBackend extends Logging { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 8aeb16fe5d8c8..2bfa1cea4b26f 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -216,16 +216,32 @@ private[spark] class Executor( */ private var heartbeatFailures = 0 + /** + * Flag to prevent launching new tasks while decommissioned. There could be a race condition + * accessing this, but decommissioning is only intended to help not be a hard stop. + */ + private var decommissioned = false + heartbeater.start() metricsPoller.start() private[executor] def numRunningTasks: Int = runningTasks.size() + /** + * Mark an executor for decommissioning and avoid launching new tasks. + */ + private[spark] def decommission(): Unit = { + decommissioned = true + } + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { val tr = new TaskRunner(context, taskDescription) runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) + if (decommissioned) { + log.error(s"Launching a task while in decommissioned state.") + } } def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/internal/config/Worker.scala b/core/src/main/scala/org/apache/spark/internal/config/Worker.scala index f1eaae29f18df..2b175c1e14ee5 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Worker.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Worker.scala @@ -71,4 +71,9 @@ private[spark] object Worker { ConfigBuilder("spark.worker.ui.compressedLogFileLengthCacheSize") .intConf .createWithDefault(100) + + private[spark] val WORKER_DECOMMISSION_ENABLED = + ConfigBuilder("spark.worker.decommission.enabled") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 64d2032a12721..a26b5791fa08b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -361,6 +361,7 @@ abstract class RDD[T: ClassTag]( readCachedBlock = false computeOrReadCheckpoint(partition, context) }) match { + // Block hit. case Left(blockResult) => if (readCachedBlock) { val existingMetrics = context.taskMetrics().inputMetrics @@ -374,6 +375,7 @@ abstract class RDD[T: ClassTag]( } else { new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]]) } + // Need to compute the block. 
case Right(iter) => new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]]) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 46a35b6a2eaf9..ee31093ec0652 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -58,3 +58,11 @@ private [spark] object LossReasonPending extends ExecutorLossReason("Pending los private[spark] case class SlaveLost(_message: String = "Slave lost", workerLost: Boolean = false) extends ExecutorLossReason(_message) + +/** + * A loss reason that means the executor is marked for decommissioning. + * + * This is used by the task scheduler to remove state associated with the executor, but + * not yet fail any tasks that were running in the executor before the executor is "fully" lost. + */ +private [spark] object ExecutorDecommission extends ExecutorLossReason("Executor decommission.") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 80805df256a15..2e2851eb9070b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -88,6 +88,10 @@ private[spark] class Pool( schedulableQueue.asScala.foreach(_.executorLost(executorId, host, reason)) } + override def executorDecommission(executorId: String): Unit = { + schedulableQueue.asScala.foreach(_.executorDecommission(executorId)) + } + override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { var shouldRevive = false for (schedulable <- schedulableQueue.asScala) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index b6f88ed0a93aa..8cc239c81d11a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -43,6 +43,7 @@ private[spark] trait Schedulable { def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable def executorLost(executorId: String, host: String, reason: ExecutorLossReason): Unit + def executorDecommission(executorId: String): Unit def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean def getSortedTaskSetQueue: ArrayBuffer[TaskSetManager] } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 9159d2a0158d5..4752353046c19 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -27,6 +27,9 @@ private[spark] trait SchedulerBackend { def start(): Unit def stop(): Unit + /** + * Update the current offers and schedule tasks + */ def reviveOffers(): Unit def defaultParallelism(): Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 15f5d20e9be75..e9e638a3645ac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -98,6 +98,11 @@ private[spark] trait TaskScheduler { */ def applicationId(): String = appId + /** + * Process a decommissioning executor. 
+ */ + def executorDecommission(executorId: String): Unit + /** * Process a lost executor */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bf92081d13907..1b197c4cca53e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -734,6 +734,11 @@ private[spark] class TaskSchedulerImpl( } } + override def executorDecommission(executorId: String): Unit = { + rootPool.executorDecommission(executorId) + backend.reviveOffers() + } + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = { var failedExecutor: Option[String] = None diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 2ce11347ade39..18684ee8ebbc2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -1083,6 +1083,12 @@ private[spark] class TaskSetManager( levels.toArray } + def executorDecommission(execId: String): Unit = { + recomputeLocality() + // Future consideration: if an executor is decommissioned it may make sense to add the current + // tasks to the spec exec queue. + } + def recomputeLocality(): Unit = { // A zombie TaskSetManager may reach here while executorLost happens if (isZombie) return diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 283390814a6c0..8db0122f17ab4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -94,6 +94,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: ExecutorLossReason) extends CoarseGrainedClusterMessage + case class DecommissionExecutor(executorId: String) extends CoarseGrainedClusterMessage + case class RemoveWorker(workerId: String, host: String, message: String) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 63aa04986b073..6e1efdaf5beb2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -92,6 +92,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. 
private val executorsPendingLossReason = new HashSet[String] + // Executors which are being decommissioned + protected val executorsPendingDecommission = new HashSet[String] + // A map of ResourceProfile id to map of hostname with its possible task number running on it @GuardedBy("CoarseGrainedSchedulerBackend.this") protected var rpHostToLocalTaskCount: Map[Int, Map[String, Int]] = Map.empty @@ -185,11 +188,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor)) removeExecutor(executorId, reason) + case DecommissionExecutor(executorId) => + logError(s"Received decommission executor message ${executorId}.") + decommissionExecutor(executorId) + + case RemoveWorker(workerId, host, message) => + removeWorker(workerId, host, message) + case LaunchedExecutor(executorId) => executorDataMap.get(executorId).foreach { data => data.freeCores = data.totalCores } makeOffers(executorId) + case e => + logError(s"Received unexpected message. ${e}") } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -257,6 +269,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeWorker(workerId, host, message) context.reply(true) + case DecommissionExecutor(executorId) => + logError(s"Received decommission executor message ${executorId}.") + decommissionExecutor(executorId) + context.reply(true) + case RetrieveSparkAppConfig(resourceProfileId) => val rp = scheduler.sc.resourceProfileManager.resourceProfileFromId(resourceProfileId) val reply = SparkAppConfig( @@ -265,6 +282,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp Option(delegationTokens.get()), rp) context.reply(reply) + case e => + logError(s"Received unexpected ask ${e}") } // Make fake resource offers on all executors @@ -365,6 +384,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp addressToExecutorId -= executorInfo.executorAddress executorDataMap -= executorId executorsPendingLossReason -= executorId + executorsPendingDecommission -= executorId executorsPendingToRemove.remove(executorId).getOrElse(false) } totalCoreCount.addAndGet(-executorInfo.totalCores) @@ -389,6 +409,35 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp scheduler.workerRemoved(workerId, host, message) } + /** + * Mark a given executor as decommissioned and stop making resource offers for it. + */ + private def decommissionExecutor(executorId: String): Boolean = { + val shouldDisable = CoarseGrainedSchedulerBackend.this.synchronized { + // Only bother decommissioning executors which are alive. + if (isExecutorActive(executorId)) { + executorsPendingDecommission += executorId + true + } else { + false + } + } + + if (shouldDisable) { + logInfo(s"Starting decommissioning executor $executorId.") + try { + scheduler.executorDecommission(executorId) + } catch { + case e: Exception => + logError(s"Unexpected error during decommissioning ${e.toString}", e) + } + logInfo(s"Finished decommissioning executor $executorId.") + } else { + logInfo(s"Skipping decommissioning of executor $executorId.") + } + shouldDisable + } + /** * Stop making resource offers for the given executor. The executor is marked as lost with * the loss reason still pending. 
@@ -511,8 +560,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } protected def removeWorker(workerId: String, host: String, message: String): Unit = { - driverEndpoint.ask[Boolean](RemoveWorker(workerId, host, message)).failed.foreach(t => - logError(t.getMessage, t))(ThreadUtils.sameThread) + driverEndpoint.send(RemoveWorker(workerId, host, message)) + } + + /** + * Called by subclasses when notified of a decommissioning executor. + */ + private[spark] def decommissionExecutor(executorId: String): Unit = { + if (driverEndpoint != null) { + logInfo("Propegating executor decommission to driver.") + driverEndpoint.send(DecommissionExecutor(executorId)) + } } def sufficientResourcesRegistered(): Boolean = true @@ -543,7 +601,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def isExecutorActive(id: String): Boolean = synchronized { executorDataMap.contains(id) && !executorsPendingToRemove.contains(id) && - !executorsPendingLossReason.contains(id) + !executorsPendingLossReason.contains(id) && + !executorsPendingDecommission.contains(id) + } override def maxNumConcurrentTasks(): Int = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index d91d78b29f98d..42c46464d79e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -174,6 +174,12 @@ private[spark] class StandaloneSchedulerBackend( removeExecutor(fullId.split("/")(1), reason) } + override def executorDecommissioned(fullId: String, message: String) { + logInfo("Asked to decommission executor") + decommissionExecutor(fullId.split("/")(1)) + logInfo("Executor %s decommissioned: %s".format(fullId, message)) + } + override def workerRemoved(workerId: String, host: String, message: String): Unit = { logInfo("Worker %s removed: %s".format(workerId, message)) removeWorker(workerId, host, message) diff --git a/core/src/main/scala/org/apache/spark/util/SignalUtils.scala b/core/src/main/scala/org/apache/spark/util/SignalUtils.scala index 5a24965170cef..230195da2a121 100644 --- a/core/src/main/scala/org/apache/spark/util/SignalUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/SignalUtils.scala @@ -60,7 +60,7 @@ private[spark] object SignalUtils extends Logging { if (SystemUtils.IS_OS_UNIX) { try { val handler = handlers.getOrElseUpdate(signal, { - logInfo("Registered signal handler for " + signal) + logInfo("Registering signal handler for " + signal) new ActionHandler(new Signal(signal)) }) handler.register(action) diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index a1d3077b8fc87..a3e39d7f53728 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.{ApplicationInfo, Master} import org.apache.spark.deploy.worker.Worker -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import 
org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils @@ -44,13 +44,13 @@ class AppClientSuite with Eventually with ScalaFutures { private val numWorkers = 2 - private val conf = new SparkConf() - private val securityManager = new SecurityManager(conf) + private var conf: SparkConf = null private var masterRpcEnv: RpcEnv = null private var workerRpcEnvs: Seq[RpcEnv] = null private var master: Master = null private var workers: Seq[Worker] = null + private var securityManager: SecurityManager = null /** * Start the local cluster. @@ -58,6 +58,8 @@ class AppClientSuite */ override def beforeAll(): Unit = { super.beforeAll() + conf = new SparkConf().set(config.Worker.WORKER_DECOMMISSION_ENABLED.key, "true") + securityManager = new SecurityManager(conf) masterRpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityManager) workerRpcEnvs = (0 until numWorkers).map { i => RpcEnv.create(Worker.SYSTEM_NAME + i, "localhost", 0, conf, securityManager) @@ -111,8 +113,23 @@ class AppClientSuite assert(apps.head.getExecutorLimit === numExecutorsRequested, s"executor request failed") } + + // Save the executor id before decommissioning so we can kill it + val application = getApplications().head + val executors = application.executors + val executorId: String = executors.head._2.fullId + + // Send a decommission self to all the workers + // Note: normally the worker would send this on their own. + workers.foreach(worker => worker.decommissionSelf()) + + // Decommissioning is async. + eventually(timeout(1.seconds), interval(10.millis)) { + // We only record decommissioning for the executor we've requested + assert(ci.listener.execDecommissionedList.size === 1) + } + // Send request to kill executor, verify request was made - val executorId: String = getApplications().head.executors.head._2.fullId whenReady( ci.client.killExecutors(Seq(executorId)), timeout(10.seconds), @@ -120,6 +137,15 @@ class AppClientSuite assert(acknowledged) } + // Verify that asking for executors on the decommissioned workers fails + whenReady( + ci.client.requestTotalExecutors(numExecutorsRequested), + timeout(10.seconds), + interval(10.millis)) { acknowledged => + assert(acknowledged) + } + assert(getApplications().head.executors.size === 0) + // Issue stop command for Client to disconnect from Master ci.client.stop() @@ -189,6 +215,7 @@ class AppClientSuite val deadReasonList = new ConcurrentLinkedQueue[String]() val execAddedList = new ConcurrentLinkedQueue[String]() val execRemovedList = new ConcurrentLinkedQueue[String]() + val execDecommissionedList = new ConcurrentLinkedQueue[String]() def connected(id: String): Unit = { connectedIdList.add(id) @@ -218,6 +245,10 @@ class AppClientSuite execRemovedList.add(id) } + def executorDecommissioned(id: String, message: String): Unit = { + execDecommissionedList.add(id) + } + def workerRemoved(workerId: String, host: String, message: String): Unit = {} } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 101e60c73e9f8..e40b63fe13cb1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -167,6 +167,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 + override def executorDecommission(executorId: String) = 
{} override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None @@ -707,6 +708,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])], blockManagerId: BlockManagerId, executorUpdates: Map[(Int, Int), ExecutorMetrics]): Boolean = true + override def executorDecommission(executorId: String): Unit = {} override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index 4e71ec1ea7b37..9f593e0039adc 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -89,6 +89,7 @@ private class DummyTaskScheduler extends TaskScheduler { override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {} override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 + override def executorDecommission(executorId: String): Unit = {} override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} override def workerRemoved(workerId: String, host: String, message: String): Unit = {} override def applicationAttemptId(): Option[String] = None diff --git a/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala new file mode 100644 index 0000000000000..15733b0d932ec --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/WorkerDecommissionSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import java.util.concurrent.Semaphore + +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.scheduler.cluster.StandaloneSchedulerBackend +import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils} + +class WorkerDecommissionSuite extends SparkFunSuite with LocalSparkContext { + + override def beforeEach(): Unit = { + val conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true) + + sc = new SparkContext("local-cluster[2, 1, 1024]", "test", conf) + } + + test("verify task with no decommissioning works as expected") { + val input = sc.parallelize(1 to 10) + input.count() + val sleepyRdd = input.mapPartitions{ x => + Thread.sleep(100) + x + } + assert(sleepyRdd.count() === 10) + } + + test("verify a task with all workers decommissioned succeeds") { + val input = sc.parallelize(1 to 10) + // Do a count to wait for the executors to be registered. + input.count() + val sleepyRdd = input.mapPartitions{ x => + Thread.sleep(50) + x + } + // Listen for the job + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + sem.release() + } + }) + // Start the task. + val asyncCount = sleepyRdd.countAsync() + // Wait for the job to have started + sem.acquire(1) + // Decommission all the executors, this should not halt the current task. + // decom.sh message passing is tested manually. + val sched = sc.schedulerBackend.asInstanceOf[StandaloneSchedulerBackend] + val execs = sched.getExecutorIds() + execs.foreach(execId => sched.decommissionExecutor(execId)) + val asyncCountResult = ThreadUtils.awaitResult(asyncCount, 2.seconds) + assert(asyncCountResult === 10) + // Try and launch task after decommissioning, this should fail + val postDecommissioned = input.map(x => x) + val postDecomAsyncCount = postDecommissioned.countAsync() + val thrown = intercept[java.util.concurrent.TimeoutException]{ + val result = ThreadUtils.awaitResult(postDecomAsyncCount, 2.seconds) + } + assert(postDecomAsyncCount.isCompleted === false, + "After exec decommission new task could not launch") + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala index 09943b7974ed9..f42f3415baa15 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -55,6 +55,9 @@ private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) { } } + def workerDecommissioning: Boolean = + sparkConf.get(org.apache.spark.internal.config.Worker.WORKER_DECOMMISSION_ENABLED) + def nodeSelector: Map[String, String] = KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala index 6a26df2997fd2..f575241de9540 100644 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.k8s._ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python._ import org.apache.spark.rpc.RpcEndpointAddress @@ -33,7 +34,7 @@ import org.apache.spark.util.Utils private[spark] class BasicExecutorFeatureStep( kubernetesConf: KubernetesExecutorConf, secMgr: SecurityManager) - extends KubernetesFeatureConfigStep { + extends KubernetesFeatureConfigStep with Logging { // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf private val executorContainerImage = kubernetesConf @@ -186,6 +187,21 @@ private[spark] class BasicExecutorFeatureStep( .endResources() .build() }.getOrElse(executorContainer) + val containerWithLifecycle = + if (!kubernetesConf.workerDecommissioning) { + logInfo("Decommissioning not enabled, skipping shutdown script") + containerWithLimitCores + } else { + logInfo("Adding decommission script to lifecycle") + new ContainerBuilder(containerWithLimitCores).withNewLifecycle() + .withNewPreStop() + .withNewExec() + .addToCommand("/opt/decom.sh") + .endExec() + .endPreStop() + .endLifecycle() + .build() + } val ownerReference = kubernetesConf.driverPod.map { pod => new OwnerReferenceBuilder() .withController(true) @@ -213,6 +229,6 @@ private[spark] class BasicExecutorFeatureStep( kubernetesConf.get(KUBERNETES_EXECUTOR_SCHEDULER_NAME) .foreach(executorPod.getSpec.setSchedulerName) - SparkPod(executorPod, containerWithLimitCores) + SparkPod(executorPod, containerWithLifecycle) } } diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 6ed37fc637b31..cc65a7da12eef 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -30,7 +30,7 @@ ARG spark_uid=185 RUN set -ex && \ apt-get update && \ ln -s /lib /lib64 && \ - apt install -y bash tini libc6 libpam-modules krb5-user libnss3 && \ + apt install -y bash tini libc6 libpam-modules krb5-user libnss3 procps && \ mkdir -p /opt/spark && \ mkdir -p /opt/spark/examples && \ mkdir -p /opt/spark/work-dir && \ @@ -45,6 +45,7 @@ COPY jars /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY kubernetes/dockerfiles/spark/entrypoint.sh /opt/ +COPY kubernetes/dockerfiles/spark/decom.sh /opt/ COPY examples /opt/spark/examples COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data @@ -53,6 +54,7 @@ ENV SPARK_HOME /opt/spark WORKDIR /opt/spark/work-dir RUN chmod g+w /opt/spark/work-dir +RUN chmod a+x /opt/decom.sh ENTRYPOINT [ "/opt/entrypoint.sh" ] diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh new file mode 100755 index 0000000000000..8a5208d49a70f --- /dev/null +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/decom.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under 
one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +set -ex +echo "Asked to decommission" +# Find the pid to signal +date | tee -a ${LOG} +WORKER_PID=$(ps -o pid -C java | tail -n 1| awk '{ sub(/^[ \t]+/, ""); print }') +echo "Using worker pid $WORKER_PID" +kill -s SIGPWR ${WORKER_PID} +# For now we expect this to timeout, since we don't start exiting the backend. +echo "Waiting for worker pid to exit" +# If the worker does exit stop blocking the cleanup. +timeout 60 tail --pid=${WORKER_PID} -f /dev/null +date +echo "Done" +date +sleep 30 diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh index 6ee3523c8edab..05ab782caecae 100755 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/entrypoint.sh @@ -30,9 +30,9 @@ set -e # If there is no passwd entry for the container UID, attempt to create one if [ -z "$uidentry" ] ; then if [ -w /etc/passwd ] ; then - echo "$myuid:x:$myuid:$mygid:${SPARK_USER_NAME:-anonymous uid}:$SPARK_HOME:/bin/false" >> /etc/passwd + echo "$myuid:x:$myuid:$mygid:${SPARK_USER_NAME:-anonymous uid}:$SPARK_HOME:/bin/false" >> /etc/passwd else - echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" + echo "Container ENTRYPOINT failed to add passwd entry for anonymous UID" fi fi @@ -59,7 +59,7 @@ fi # If HADOOP_HOME is set and SPARK_DIST_CLASSPATH is not set, set it here so Hadoop jars are available to the executor. # It does not set SPARK_DIST_CLASSPATH if already set, to avoid overriding customizations of this value from elsewhere e.g. Docker/K8s. if [ -n ${HADOOP_HOME} ] && [ -z ${SPARK_DIST_CLASSPATH} ]; then - export SPARK_DIST_CLASSPATH=$($HADOOP_HOME/bin/hadoop classpath) + export SPARK_DIST_CLASSPATH=$($HADOOP_HOME/bin/hadoop classpath) fi if ! [ -z ${HADOOP_CONF_DIR+x} ]; then diff --git a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh index 607bb243458a6..292abe91d35b6 100755 --- a/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh +++ b/resource-managers/kubernetes/integration-tests/dev/dev-run-integration-tests.sh @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -set -xo errexit +set -exo errexit TEST_ROOT_DIR=$(git rev-parse --show-toplevel) DEPLOY_MODE="minikube" @@ -42,6 +42,9 @@ SCALA_VERSION=$("$MVN" help:evaluate -Dexpression=scala.binary.version 2>/dev/nu | grep -v "WARNING"\ | tail -n 1) +export SCALA_VERSION +echo $SCALA_VERSION + # Parse arguments while (( "$#" )); do case $1 in @@ -110,7 +113,8 @@ while (( "$#" )); do shift ;; *) - break + echo "Unexpected command line flag $2 $1." + exit 1 ;; esac shift @@ -164,6 +168,7 @@ properties+=( -Dspark.kubernetes.test.jvmImage=$JVM_IMAGE_NAME -Dspark.kubernetes.test.pythonImage=$PYTHON_IMAGE_NAME -Dspark.kubernetes.test.rImage=$R_IMAGE_NAME + -Dlog4j.logger.org.apache.spark=DEBUG ) $TEST_ROOT_DIR/build/mvn integration-test -f $TEST_ROOT_DIR/pom.xml -pl resource-managers/kubernetes/integration-tests -am -Pscala-$SCALA_VERSION -Pkubernetes -Pkubernetes-integration-tests ${properties[@]} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala new file mode 100644 index 0000000000000..f5eab6e4bbad6 --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.integrationtest + +import org.apache.spark.internal.config.Worker + +private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => + + import DecommissionSuite._ + import KubernetesSuite.k8sTestTag + + test("Test basic decommissioning", k8sTestTag) { + sparkAppConf + .set(Worker.WORKER_DECOMMISSION_ENABLED.key, "true") + .set("spark.kubernetes.pyspark.pythonVersion", "3") + .set("spark.kubernetes.container.image", pyImage) + + runSparkApplicationAndVerifyCompletion( + appResource = PYSPARK_DECOMISSIONING, + mainClass = "", + expectedLogOnCompletion = Seq("decommissioning executor", + "Finished waiting, stopping Spark"), + appArgs = Array.empty[String], + driverPodChecker = doBasicDriverPyPodCheck, + executorPodChecker = doBasicExecutorPyPodCheck, + appLocator = appLocator, + isJVM = false, + decommissioningTest = true) + } +} + +private[spark] object DecommissionSuite { + val TEST_LOCAL_PYSPARK: String = "local:///opt/spark/tests/" + val PYSPARK_DECOMISSIONING: String = TEST_LOCAL_PYSPARK + "decommissioning.py" +} diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 0d4fcccc35cf9..61e1f27b55462 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -42,7 +42,9 @@ import org.apache.spark.internal.config._ class KubernetesSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite with PythonTestsSuite with ClientModeTestsSuite with PodTemplateSuite with PVTestsSuite - with DepsTestsSuite with RTestsSuite with Logging with Eventually with Matchers { + with DepsTestsSuite with DecommissionSuite with RTestsSuite with Logging with Eventually + with Matchers { + import KubernetesSuite._ @@ -254,6 +256,7 @@ class KubernetesSuite extends SparkFunSuite } } + // scalastyle:off argcount protected def runSparkApplicationAndVerifyCompletion( appResource: String, mainClass: String, @@ -264,60 +267,120 @@ class KubernetesSuite extends SparkFunSuite appLocator: String, isJVM: Boolean, pyFiles: Option[String] = None, - executorPatience: Option[(Option[Interval], Option[Timeout])] = None): Unit = { + executorPatience: Option[(Option[Interval], Option[Timeout])] = None, + decommissioningTest: Boolean = false): Unit = { + + // scalastyle:on argcount val appArguments = SparkAppArguments( mainAppResource = appResource, mainClass = mainClass, appArgs = appArgs) - SparkAppLauncher.launch( - appArguments, - sparkAppConf, - TIMEOUT.value.toSeconds.toInt, - sparkHomeDir, - isJVM, - pyFiles) - val driverPod = kubernetesTestComponents.kubernetesClient - .pods() - .withLabel("spark-app-locator", appLocator) - .withLabel("spark-role", "driver") - .list() - .getItems - .get(0) - driverPodChecker(driverPod) val execPods = scala.collection.mutable.Map[String, Pod]() + val (patienceInterval, patienceTimeout) = { + executorPatience match { + case Some(patience) => (patience._1.getOrElse(INTERVAL), patience._2.getOrElse(TIMEOUT)) + case _ => (INTERVAL, TIMEOUT) + } + } + def checkPodReady(namespace: String, name: String) = { + val execPod = kubernetesTestComponents.kubernetesClient + 
.pods() + .inNamespace(namespace) + .withName(name) + .get() + val resourceStatus = execPod.getStatus + val conditions = resourceStatus.getConditions().asScala + val conditionTypes = conditions.map(_.getType()) + val readyConditions = conditions.filter{cond => cond.getType() == "Ready"} + val result = readyConditions + .map(cond => cond.getStatus() == "True") + .headOption.getOrElse(false) + result + } val execWatcher = kubernetesTestComponents.kubernetesClient .pods() .withLabel("spark-app-locator", appLocator) .withLabel("spark-role", "executor") .watch(new Watcher[Pod] { - logInfo("Beginning watch of executors") + logDebug("Beginning watch of executors") override def onClose(cause: KubernetesClientException): Unit = logInfo("Ending watch of executors") override def eventReceived(action: Watcher.Action, resource: Pod): Unit = { val name = resource.getMetadata.getName + val namespace = resource.getMetadata().getNamespace() action match { - case Action.ADDED | Action.MODIFIED => + case Action.MODIFIED => + execPods(name) = resource + case Action.ADDED => + logDebug(s"Add event received for $name.") execPods(name) = resource + // If testing decommissioning start a thread to simulate + // decommissioning. + if (decommissioningTest && execPods.size == 1) { + // Wait for all the containers in the pod to be running + logDebug("Waiting for first pod to become OK prior to deletion") + Eventually.eventually(patienceTimeout, patienceInterval) { + val result = checkPodReady(namespace, name) + result shouldBe (true) + } + // Sleep a small interval to allow execution of job + logDebug("Sleeping before killing pod.") + Thread.sleep(2000) + // Delete the pod to simulate cluster scale down/migration. + val pod = kubernetesTestComponents.kubernetesClient.pods().withName(name) + pod.delete() + logDebug(s"Triggered pod decom/delete: $name deleted") + } case Action.DELETED | Action.ERROR => execPods.remove(name) } } }) - val (patienceInterval, patienceTimeout) = { - executorPatience match { - case Some(patience) => (patience._1.getOrElse(INTERVAL), patience._2.getOrElse(TIMEOUT)) - case _ => (INTERVAL, TIMEOUT) - } - } + logDebug("Starting Spark K8s job") + SparkAppLauncher.launch( + appArguments, + sparkAppConf, + TIMEOUT.value.toSeconds.toInt, + sparkHomeDir, + isJVM, + pyFiles) + val driverPod = kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "driver") + .list() + .getItems + .get(0) + + driverPodChecker(driverPod) + // If we're testing decommissioning we delete all the executors, but we should have + // an executor at some point. Eventually.eventually(patienceTimeout, patienceInterval) { execPods.values.nonEmpty should be (true) } + // If decommissioning we need to wait and check the executors were removed + if (decommissioningTest) { + // Sleep a small interval to ensure everything is registered. + Thread.sleep(100) + // Wait for the executors to become ready + Eventually.eventually(patienceTimeout, patienceInterval) { + val anyReadyPods = ! 
execPods.map{ + case (name, resource) => + (name, resource.getMetadata().getNamespace()) + }.filter{ + case (name, namespace) => checkPodReady(namespace, name) + }.isEmpty + val podsEmpty = execPods.values.isEmpty + val podsReadyOrDead = anyReadyPods || podsEmpty + podsReadyOrDead shouldBe (true) + } + } execWatcher.close() execPods.values.foreach(executorPodChecker(_)) - Eventually.eventually(TIMEOUT, patienceInterval) { + Eventually.eventually(patienceTimeout, patienceInterval) { expectedLogOnCompletion.foreach { e => assert(kubernetesTestComponents.kubernetesClient .pods() @@ -425,5 +488,5 @@ private[spark] object KubernetesSuite { val SPARK_REMOTE_MAIN_CLASS: String = "org.apache.spark.examples.SparkRemoteFileTest" val SPARK_DRIVER_MAIN_CLASS: String = "org.apache.spark.examples.DriverSubmissionTest" val TIMEOUT = PatienceConfiguration.Timeout(Span(2, Minutes)) - val INTERVAL = PatienceConfiguration.Interval(Span(2, Seconds)) + val INTERVAL = PatienceConfiguration.Interval(Span(1, Seconds)) } diff --git a/resource-managers/kubernetes/integration-tests/tests/decommissioning.py b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py new file mode 100644 index 0000000000000..f68f24d49763d --- /dev/null +++ b/resource-managers/kubernetes/integration-tests/tests/decommissioning.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import time + +from pyspark.sql import SparkSession + + +if __name__ == "__main__": + """ + Usage: decommissioning + """ + print("Starting decom test") + spark = SparkSession \ + .builder \ + .appName("PyMemoryTest") \ + .getOrCreate() + sc = spark._sc + rdd = sc.parallelize(range(10)) + rdd.collect() + print("Waiting to give nodes time to finish.") + time.sleep(5) + rdd.collect() + print("Waiting some more....") + time.sleep(10) + rdd.collect() + print("Finished waiting, stopping Spark.") + spark.stop() + print("Done, exiting Python") + sys.exit(0) diff --git a/sbin/decommission-slave.sh b/sbin/decommission-slave.sh new file mode 100644 index 0000000000000..4bbf257ff1d3a --- /dev/null +++ b/sbin/decommission-slave.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# A shell script to decommission all workers on a single slave +# +# Environment variables +# +# SPARK_WORKER_INSTANCES The number of worker instances that should be +# running on this slave. Default is 1. + +# Usage: decommission-slave.sh [--block-until-exit] +# Decommissions all workers on this slave machine + +set -ex + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" + +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ "$SPARK_WORKER_INSTANCES" = "" ]; then + "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker 1 +else + for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do + "${SPARK_HOME}/sbin"/spark-daemon.sh decommission org.apache.spark.deploy.worker.Worker $(( $i + 1 )) + done +fi + +# Check if --block-until-exit is set. +# This is done for systems which block on the decommissioning script and on exit +# shut down the entire system (e.g. K8s). +if [ "$1" == "--block-until-exit" ]; then + shift + # For now we only block on the 0th instance if there are multiple instances. + instance=$1 + pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid" + wait $pid +fi diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 6de67e039b48f..81f2fd40a706f 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -215,6 +215,21 @@ case $option in fi ;; + (decommission) + + if [ -f $pid ]; then + TARGET_ID="$(cat "$pid")" + if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then + echo "decommissioning $command" + kill -s SIGPWR "$TARGET_ID" + else + echo "no $command to decommission" + fi + else + echo "no $command to decommission" + fi + ;; + (status) if [ -f $pid ]; then From d0f961476031b62bda0d4d41f7248295d651ea92 Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Fri, 14 Feb 2020 21:46:01 +0000 Subject: [PATCH 051/185] [SPARK-30289][SQL] Partitioned by Nested Column for `InMemoryTable` ### What changes were proposed in this pull request? 1. `InMemoryTable` was flattening the nested columns, and then the flattened columns were used to look up the indices, which is not correct. This PR implements partitioning by nested columns for `InMemoryTable`. ### Why are the changes needed? This PR implements partitioning by nested columns for `InMemoryTable`, so we can test this feature in DSv2. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing unit tests and new tests. Closes #26929 from dbtsai/addTests.
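To make the change below easier to follow, here is a minimal, Spark-free sketch of the key-extraction idea: resolve a dotted partition reference such as `ts.timezone` by walking nested rows level by level instead of flattening the schema and indexing into the top-level row. The `Row` type and the names below are illustrative stand-ins, not Spark's `InternalRow`/`StructType` API.

```scala
// Illustrative sketch only; a toy Row type stands in for Spark's InternalRow.
object NestedKeySketch {
  final case class Row(values: Map[String, Any])

  // Walk a field path such as List("ts", "timezone") one level at a time.
  @annotation.tailrec
  def extract(path: List[String], row: Row): Any = path match {
    case Nil => throw new IllegalArgumentException("empty field path")
    case name :: Nil => row.values(name)
    case name :: rest => row.values(name) match {
      case nested: Row => extract(rest, nested)
      case other => throw new IllegalArgumentException(s"$name is not a struct: $other")
    }
  }

  def main(args: Array[String]): Unit = {
    val row = Row(Map("id" -> 1, "ts" -> Row(Map("timezone" -> "America/Los_Angeles"))))
    // The partition key that partitionedBy($"ts.timezone") should produce for this row.
    println(extract(List("ts", "timezone"), row)) // America/Los_Angeles
  }
}
```

The actual change in `InMemoryTable` below applies the same walk over `InternalRow` and `StructType`.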
Authored-by: DB Tsai Signed-off-by: DB Tsai --- .../spark/sql/connector/InMemoryTable.scala | 35 ++++++-- .../spark/sql/DataFrameWriterV2Suite.scala | 86 ++++++++++++++++++- 2 files changed, 114 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index c9e4e0aad5704..0187ae31e2d1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -26,7 +26,7 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog._ -import org.apache.spark.sql.connector.expressions.{IdentityTransform, Transform} +import org.apache.spark.sql.connector.expressions.{IdentityTransform, NamedReference, Transform} import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull} @@ -59,10 +59,30 @@ class InMemoryTable( def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq - private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) - private val partIndexes = partFieldNames.map(schema.fieldIndex) + private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref => + schema.findNestedField(ref.fieldNames(), includeCollections = false) match { + case Some(_) => ref.fieldNames() + case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.") + } + } - private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_)) + private def getKey(row: InternalRow): Seq[Any] = { + def extractor(fieldNames: Array[String], schema: StructType, row: InternalRow): Any = { + val index = schema.fieldIndex(fieldNames(0)) + val value = row.toSeq(schema).apply(index) + if (fieldNames.length > 1) { + (value, schema(index).dataType) match { + case (row: InternalRow, nestedSchema: StructType) => + extractor(fieldNames.drop(1), nestedSchema, row) + case (_, dataType) => + throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") + } + } else { + value + } + } + partCols.map(fieldNames => extractor(fieldNames, schema, row)) + } def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { data.foreach(_.rows.foreach { row => @@ -146,8 +166,10 @@ class InMemoryTable( } private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - val deleteKeys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + val deleteKeys = InMemoryTable.filtersToKeys( + dataMap.keys, partCols.map(_.toSeq.quoted), filters) dataMap --= deleteKeys withData(messages.map(_.asInstanceOf[BufferedRows])) } @@ -161,7 +183,8 @@ class InMemoryTable( } override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { - dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters) + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index d49dc58e93ddb..cd157086a8b8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -17,20 +17,24 @@ package org.apache.spark.sql +import java.sql.Timestamp + import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} -import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.connector.{InMemoryTable, InMemoryTableCatalog} import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.types.TimestampType import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter { @@ -550,4 +554,84 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(replaced.partitioning.isEmpty) assert(replaced.properties === defaultOwnership.asJava) } + + test("SPARK-30289 Create: partitioned by nested column") { + val schema = new StructType().add("ts", new StructType() + .add("created", TimestampType) + .add("modified", TimestampType) + .add("timezone", StringType)) + + val data = Seq( + Row(Row(Timestamp.valueOf("2019-06-01 10:00:00"), Timestamp.valueOf("2019-09-02 07:00:00"), + "America/Los_Angeles")), + Row(Row(Timestamp.valueOf("2019-08-26 18:00:00"), Timestamp.valueOf("2019-09-26 18:00:00"), + "America/Los_Angeles")), + Row(Row(Timestamp.valueOf("2018-11-23 18:00:00"), Timestamp.valueOf("2018-12-22 18:00:00"), + "America/New_York"))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema) + + df.writeTo("testcat.table_name") + .partitionedBy($"ts.timezone") + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + .asInstanceOf[InMemoryTable] + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(IdentityTransform(FieldReference(Array("ts", "timezone"))))) + checkAnswer(spark.table(table.name), data) + assert(table.dataMap.toArray.length == 2) + assert(table.dataMap(Seq(UTF8String.fromString("America/Los_Angeles"))).rows.size == 2) + assert(table.dataMap(Seq(UTF8String.fromString("America/New_York"))).rows.size == 1) + + // TODO: `DataSourceV2Strategy` can not translate nested fields into source filter yet + // so the following sql will fail. 
+ // sql("DELETE FROM testcat.table_name WHERE ts.timezone = \"America/Los_Angeles\"") + } + + test("SPARK-30289 Create: partitioned by multiple transforms on nested columns") { + spark.table("source") + .withColumn("ts", struct( + lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created", + lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", + lit("America/Los_Angeles") as "timezone")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy( + years($"ts.created"), months($"ts.created"), days($"ts.created"), hours($"ts.created"), + years($"ts.modified"), months($"ts.modified"), days($"ts.modified"), hours($"ts.modified") + ) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq( + YearsTransform(FieldReference(Array("ts", "created"))), + MonthsTransform(FieldReference(Array("ts", "created"))), + DaysTransform(FieldReference(Array("ts", "created"))), + HoursTransform(FieldReference(Array("ts", "created"))), + YearsTransform(FieldReference(Array("ts", "modified"))), + MonthsTransform(FieldReference(Array("ts", "modified"))), + DaysTransform(FieldReference(Array("ts", "modified"))), + HoursTransform(FieldReference(Array("ts", "modified"))))) + } + + test("SPARK-30289 Create: partitioned by bucket(4, ts.timezone)") { + spark.table("source") + .withColumn("ts", struct( + lit("2019-06-01 10:00:00.000000").cast("timestamp") as "created", + lit("2019-09-02 07:00:00.000000").cast("timestamp") as "modified", + lit("America/Los_Angeles") as "timezone")) + .writeTo("testcat.table_name") + .tableProperty("allow-unsupported-transforms", "true") + .partitionedBy(bucket(4, $"ts.timezone")) + .create() + + val table = catalog("testcat").loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name === "testcat.table_name") + assert(table.partitioning === Seq(BucketTransform(LiteralValue(4, IntegerType), + Seq(FieldReference(Seq("ts", "timezone")))))) + } } From 8b73b92aadd685b29ef3d9b33366f5e1fd3dae99 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Sat, 15 Feb 2020 19:49:58 +0800 Subject: [PATCH 052/185] [SPARK-30826][SQL] Respect reference case in `StringStartsWith` pushed down to parquet ### What changes were proposed in this pull request? In the PR, I propose to convert the attribute name of `StringStartsWith` pushed down to the Parquet datasource to column reference via the `nameToParquetField` map. Similar conversions are performed for other source filters pushed down to parquet. ### Why are the changes needed? This fixes the bug described in [SPARK-30826](https://issues.apache.org/jira/browse/SPARK-30826). The query from an external table: ```sql CREATE TABLE t1 (col STRING) USING parquet OPTIONS (path '$path') ``` created on top of written parquet files by `Seq("42").toDF("COL").write.parquet(path)` returns wrong empty result: ```scala spark.sql("SELECT * FROM t1 WHERE col LIKE '4%'").show +---+ |col| +---+ +---+ ``` ### Does this PR introduce any user-facing change? Yes. After the changes the result is correct for the example above: ```scala spark.sql("SELECT * FROM t1 WHERE col LIKE '4%'").show +---+ |col| +---+ | 42| +---+ ``` ### How was this patch tested? Added a test to `ParquetFilterSuite` Closes #27574 from MaxGekk/parquet-StringStartsWith-case-sens. 
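Before the diff, a small self-contained sketch of the idea behind the fix: the attribute name in a pushed-down filter has to be resolved to the column name as it is spelled in the Parquet file, honoring case (in)sensitivity. The object and function names below are hypothetical; this is not Spark's actual `ParquetFilters` code.

```scala
// Hypothetical sketch of case-(in)sensitive column resolution; not Spark's implementation.
object ColumnResolutionSketch {
  // Build a resolver from a filter's attribute name to the column name spelled in the file.
  def resolver(fileColumns: Seq[String], caseSensitive: Boolean): String => Option[String] = {
    if (caseSensitive) {
      val exact = fileColumns.map(c => c -> c).toMap
      name => exact.get(name)
    } else {
      // Group by lower-cased name; an ambiguous duplicate (e.g. "col" and "COL") is left
      // unresolved here rather than guessed at.
      val byLower = fileColumns.groupBy(_.toLowerCase)
      name => byLower.get(name.toLowerCase).collect { case Seq(only) => only }
    }
  }

  def main(args: Array[String]): Unit = {
    val resolve = resolver(Seq("COL", "id"), caseSensitive = false)
    // A filter on `col` written against the table schema resolves to the file's "COL" column.
    println(resolve("col"))                                          // Some(COL)
    println(resolver(Seq("COL", "id"), caseSensitive = true)("col")) // None
  }
}
```

The test added in `ParquetFilterSuite` below exercises exactly this situation: data written with column `COL` and queried through a table whose schema spells it `col`.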
Authored-by: Maxim Gekk Signed-off-by: Wenchen Fan --- .../datasources/parquet/ParquetFilters.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index b9b86adb438e6..948a120e0d6e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -591,7 +591,7 @@ class ParquetFilters( case sources.StringStartsWith(name, prefix) if pushDownStartWith && canMakeFilterOn(name, prefix) => Option(prefix).map { v => - FilterApi.userDefined(binaryColumn(name), + FilterApi.userDefined(binaryColumn(nameToParquetField(name).fieldName), new UserDefinedPredicate[Binary] with Serializable { private val strToBinary = Binary.fromReusedByteArray(v.getBytes) private val size = strToBinary.length diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 286bb1e920266..4e0c1c2dbe601 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1390,6 +1390,27 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } } + + test("SPARK-30826: case insensitivity of StringStartsWith attribute") { + import testImplicits._ + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + withTable("t1") { + withTempPath { dir => + val path = dir.toURI.toString + Seq("42").toDF("COL").write.parquet(path) + spark.sql( + s""" + |CREATE TABLE t1 (col STRING) + |USING parquet + |OPTIONS (path '$path') + """.stripMargin) + checkAnswer( + spark.sql("SELECT * FROM t1 WHERE col LIKE '4%'"), + Row("42")) + } + } + } + } } class ParquetV1FilterSuite extends ParquetFilterSuite { From f5238ea6cb0d2cfa69ae0488df94b29cc50065e0 Mon Sep 17 00:00:00 2001 From: "Wu, Xiaochang" Date: Sun, 16 Feb 2020 09:51:02 -0600 Subject: [PATCH 053/185] [GRAPHX][MINOR] Fix typo setRest => setDest ### What changes were proposed in this pull request? Fix typo def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) to def setDest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) ### Why are the changes needed? Typo ### Does this PR introduce any user-facing change? No ### How was this patch tested? N/A Closes #27594 from xwu99/fix-graphx-setDest. 
Authored-by: Wu, Xiaochang Signed-off-by: Sean Owen --- .../scala/org/apache/spark/graphx/impl/EdgePartition.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 8d03112a1c3dc..c0a2ba67d2942 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -465,7 +465,7 @@ class EdgePartition[ if (edgeIsActive) { val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] - ctx.setRest(dstId, localDstId, dstAttr, data(pos)) + ctx.setDest(dstId, localDstId, dstAttr, data(pos)) sendMsg(ctx) } pos += 1 @@ -511,7 +511,7 @@ private class AggregatingEdgeContext[VD, ED, A]( _srcAttr = srcAttr } - def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED): Unit = { + def setDest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED): Unit = { _dstId = dstId _localDstId = localDstId _dstAttr = dstAttr From 0a03e7e679771da8556fae72b35edf21ae71ac44 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Sun, 16 Feb 2020 09:53:12 -0600 Subject: [PATCH 054/185] [SPARK-30691][SQL][DOC][FOLLOW-UP] Make link names exactly the same as the side bar names ### What changes were proposed in this pull request? Make link names exactly the same as the side bar names ### Why are the changes needed? Make doc look better ### Does this PR introduce any user-facing change? before: ![image](https://user-images.githubusercontent.com/13592258/74578603-ad300100-4f4a-11ea-8430-11fccf31eab4.png) after: ![image](https://user-images.githubusercontent.com/13592258/74578670-eff1d900-4f4a-11ea-97d8-5908c0e50e95.png) ### How was this patch tested? Manually build and check the docs Closes #27591 from huaxingao/spark-doc-followup. 
Authored-by: Huaxin Gao Signed-off-by: Sean Owen --- docs/_data/menu-sql.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index 1e343f630f88e..38a5cf61245a6 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -157,12 +157,12 @@ - text: Auxiliary Statements url: sql-ref-syntax-aux.html subitems: - - text: Analyze statement + - text: ANALYZE url: sql-ref-syntax-aux-analyze.html subitems: - text: ANALYZE TABLE url: sql-ref-syntax-aux-analyze-table.html - - text: Caching statements + - text: CACHE url: sql-ref-syntax-aux-cache.html subitems: - text: CACHE TABLE @@ -175,7 +175,7 @@ url: sql-ref-syntax-aux-refresh-table.html - text: REFRESH url: sql-ref-syntax-aux-cache-refresh.md - - text: Describe Commands + - text: DESCRIBE url: sql-ref-syntax-aux-describe.html subitems: - text: DESCRIBE DATABASE @@ -186,7 +186,7 @@ url: sql-ref-syntax-aux-describe-function.html - text: DESCRIBE QUERY url: sql-ref-syntax-aux-describe-query.html - - text: Show commands + - text: SHOW url: sql-ref-syntax-aux-show.html subitems: - text: SHOW COLUMNS @@ -205,14 +205,14 @@ url: sql-ref-syntax-aux-show-partitions.html - text: SHOW CREATE TABLE url: sql-ref-syntax-aux-show-create-table.html - - text: Configuration Management Commands + - text: CONFIGURATION MANAGEMENT url: sql-ref-syntax-aux-conf-mgmt.html subitems: - text: SET url: sql-ref-syntax-aux-conf-mgmt-set.html - text: RESET url: sql-ref-syntax-aux-conf-mgmt-reset.html - - text: Resource Management Commands + - text: RESOURCE MANAGEMENT url: sql-ref-syntax-aux-resource-mgmt.html subitems: - text: ADD FILE From 01cc852982cd065e08f9a652c14a0514f49fb662 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Sun, 16 Feb 2020 09:55:03 -0600 Subject: [PATCH 055/185] [SPARK-30803][DOCS] Fix the home page link for Scala API document ### What changes were proposed in this pull request? Change the link to the Scala API document. ``` $ git grep "#org.apache.spark.package" docs/_layouts/global.html:
<li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li>
  • docs/index.md:* [Spark Scala API (Scaladoc)](api/scala/index.html#org.apache.spark.package) docs/rdd-programming-guide.md:[Scala](api/scala/#org.apache.spark.package), [Java](api/java/), [Python](api/python/) and [R](api/R/). ``` ### Why are the changes needed? The home page link for Scala API document is incorrect after upgrade to 3.0 ### Does this PR introduce any user-facing change? Document UI change only. ### How was this patch tested? Local test, attach screenshots below: Before: ![image](https://user-images.githubusercontent.com/4833765/74335713-c2385300-4dd7-11ea-95d8-f5a3639d2578.png) After: ![image](https://user-images.githubusercontent.com/4833765/74335727-cbc1bb00-4dd7-11ea-89d9-4dcc1310e679.png) Closes #27549 from xuanyuanking/scala-doc. Authored-by: Yuanjian Li Signed-off-by: Sean Owen --- docs/_layouts/global.html | 2 +- docs/configuration.md | 8 +- docs/graphx-programming-guide.md | 68 +++++++------- docs/index.md | 2 +- docs/ml-advanced.md | 10 +- docs/ml-classification-regression.md | 40 ++++---- docs/ml-clustering.md | 10 +- docs/ml-collaborative-filtering.md | 2 +- docs/ml-datasource.md | 4 +- docs/ml-features.md | 92 +++++++++---------- docs/ml-frequent-pattern-mining.md | 4 +- docs/ml-migration-guide.md | 36 ++++---- docs/ml-pipeline.md | 10 +- docs/ml-statistics.md | 8 +- docs/ml-tuning.md | 18 ++-- docs/mllib-clustering.md | 26 +++--- docs/mllib-collaborative-filtering.md | 4 +- docs/mllib-data-types.md | 48 +++++----- docs/mllib-decision-tree.md | 10 +- docs/mllib-dimensionality-reduction.md | 6 +- docs/mllib-ensembles.md | 10 +- docs/mllib-evaluation-metrics.md | 8 +- docs/mllib-feature-extraction.md | 34 +++---- docs/mllib-frequent-pattern-mining.md | 14 +-- docs/mllib-isotonic-regression.md | 2 +- docs/mllib-linear-methods.md | 22 ++--- docs/mllib-naive-bayes.md | 8 +- docs/mllib-optimization.md | 14 +-- docs/mllib-pmml-model-export.md | 2 +- docs/mllib-statistics.md | 28 +++--- docs/quick-start.md | 2 +- docs/rdd-programming-guide.md | 28 +++--- docs/sql-data-sources-generic-options.md | 2 +- docs/sql-data-sources-jdbc.md | 2 +- docs/sql-data-sources-json.md | 2 +- docs/sql-getting-started.md | 16 ++-- docs/sql-migration-guide.md | 4 +- docs/sql-programming-guide.md | 2 +- docs/sql-ref-syntax-aux-analyze-table.md | 2 +- docs/sql-ref-syntax-aux-cache-refresh.md | 2 +- docs/sql-ref-syntax-aux-refresh-table.md | 2 +- docs/sql-ref-syntax-aux-resource-mgmt.md | 2 +- docs/sql-ref-syntax-aux-show-tables.md | 2 +- docs/sql-ref-syntax-aux-show.md | 2 +- docs/sql-ref-syntax-ddl-drop-database.md | 2 +- docs/sql-ref-syntax-ddl-drop-function.md | 2 +- ...tax-dml-insert-overwrite-directory-hive.md | 2 +- ...f-syntax-dml-insert-overwrite-directory.md | 2 +- docs/sql-ref-syntax-dml.md | 2 +- docs/sql-ref-syntax-qry-select-clusterby.md | 2 +- ...sql-ref-syntax-qry-select-distribute-by.md | 2 +- docs/sql-ref-syntax-qry-select-sortby.md | 2 +- docs/sql-ref-syntax-qry-select.md | 2 +- docs/streaming-custom-receivers.md | 2 +- docs/streaming-kafka-integration.md | 2 +- docs/streaming-kinesis-integration.md | 2 +- docs/streaming-programming-guide.md | 42 ++++----- .../structured-streaming-programming-guide.md | 22 ++--- docs/tuning.md | 2 +- 59 files changed, 355 insertions(+), 355 deletions(-) diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d5fb18bfb06c0..d05ac6bbe129d 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -82,7 +82,7 @@