diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 34f29563c7ff..132ced820e9a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1793,7 +1793,7 @@ version ; operatorPipeRightSide - : selectClause windowClause? + : selectClause aggregationClause? windowClause? | EXTEND extendList=namedExpressionSeq | SET operatorPipeSetAssignmentSeq | DROP identifierSeq diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala index 2ee68663ad2f..b2bb949c9e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/pipeOperators.scala @@ -65,7 +65,7 @@ object EliminatePipeOperators extends Rule[LogicalPlan] { * Validates and strips PipeExpression nodes from a logical plan once the child expressions are * resolved. */ -object ValidateAndStripPipeExpressions extends Rule[LogicalPlan] { +case object ValidateAndStripPipeExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(PIPE_EXPRESSION), ruleId) { case node: LogicalPlan => @@ -78,8 +78,13 @@ object ValidateAndStripPipeExpressions extends Rule[LogicalPlan] { throw QueryCompilationErrors .pipeOperatorAggregateExpressionContainsNoAggregateFunction(p.child) } else if (!p.isAggregate) { - firstAggregateFunction.foreach { a => - throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, p.clause) + // For non-aggregate clauses, only allow aggregate functions in SELECT. + // All other clauses (EXTEND, SET, etc.) disallow aggregates. + val aggregateAllowed = p.clause == PipeOperators.selectClause + if (!aggregateAllowed) { + firstAggregateFunction.foreach { a => + throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, p.clause) + } } } p.child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d1d4a6b8c980..abc282e9c488 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -6605,7 +6605,7 @@ class AstBuilder extends DataTypeAstBuilder private def visitOperatorPipeRightSide( ctx: OperatorPipeRightSideContext, left: LogicalPlan): LogicalPlan = { - if (!SQLConf.get.getConf(SQLConf.OPERATOR_PIPE_SYNTAX_ENABLED)) { + if (!conf.getConf(SQLConf.OPERATOR_PIPE_SYNTAX_ENABLED)) { operationNotAllowed("Operator pipe SQL syntax using |>", ctx) } Option(ctx.selectClause).map { c => @@ -6614,7 +6614,7 @@ class AstBuilder extends DataTypeAstBuilder selectClause = c, lateralView = new java.util.ArrayList[LateralViewContext](), whereClause = null, - aggregationClause = null, + aggregationClause = ctx.aggregationClause, havingClause = null, windowClause = ctx.windowClause, relation = left, diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index 439d701018a2..9aca65a057a9 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -632,109 +632,6 @@ Repartition 3, true +- Relation spark_catalog.default.t[x#x,y#x] csv --- !query -table t -|> select sum(x) as result --- !query analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 19, - "stopIndex" : 24, - "fragment" : "sum(x)" - } ] -} - - --- !query -table t -|> select y, length(y) + sum(x) as result --- !query analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 34, - "stopIndex" : 39, - "fragment" : "sum(x)" - } ] -} - - --- !query -from t -|> select sum(x) --- !query analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 18, - "stopIndex" : 23, - "fragment" : "sum(x)" - } ] -} - - --- !query -from t as t_alias -|> select y, sum(x) --- !query analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 32, - "stopIndex" : 37, - "fragment" : "sum(x)" - } ] -} - - --- !query -from t as t_alias -|> select y, sum(x) group by y --- !query analysis -org.apache.spark.sql.catalyst.parser.ParseException -{ - "errorClass" : "PARSE_SYNTAX_ERROR", - "sqlState" : "42601", - "messageParameters" : { - "error" : "'group'", - "hint" : "" - } -} - - -- !query table t |> extend 1 as z @@ -3683,28 +3580,6 @@ org.apache.spark.sql.AnalysisException } --- !query -table other -|> select sum(a) as result --- !query analysis -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(a#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 23, - "stopIndex" : 28, - "fragment" : "sum(a)" - } ] -} - - -- !query table other |> aggregate @@ -4947,6 +4822,163 @@ Project [x#x, y#x] +- Relation spark_catalog.default.t[x#x,y#x] csv +-- !query +table other +|> select sum(a) as result +-- !query analysis +Aggregate [sum(a#x) AS result#xL] ++- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +table other +|> select sum(a) as total_a, avg(b) as avg_b +-- !query analysis +Aggregate [sum(a#x) AS total_a#xL, avg(b#x) AS avg_b#x] ++- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +table other +|> where b > 1 +|> select sum(a) as result +-- !query analysis +Aggregate [sum(a#x) AS result#xL] ++- Filter (b#x > 1) + +- PipeOperator + +- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +table other +|> select sum(a) as total_a +|> select total_a * 2 as doubled +-- !query analysis +Project [(total_a#xL * cast(2 as bigint)) AS doubled#xL] ++- Aggregate [sum(a#x) AS total_a#xL] + +- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +table other +|> select a, sum(b) as sum_b group by a +-- !query analysis +Aggregate [a#x], [a#x, sum(b#x) AS sum_b#xL] ++- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +select 1 as x, 2 as y, 3 as z +|> select x, y, sum(z) as total group by x, y +-- !query analysis +Aggregate [x#x, y#x], [x#x, y#x, sum(z#x) AS total#xL] ++- Project [1 AS x#x, 2 AS y#x, 3 AS z#x] + +- OneRowRelation + + +-- !query +table other +|> select a, sum(b) as sum_b group by 1 +-- !query analysis +Aggregate [a#x], [a#x, sum(b#x) AS sum_b#xL] ++- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +table other +|> select a, sum(b) as sum_b group by a +|> where sum_b > 1 +-- !query analysis +Filter (sum_b#xL > cast(1 as bigint)) ++- PipeOperator + +- Aggregate [a#x], [a#x, sum(b#x) AS sum_b#xL] + +- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + +-- !query +select 1 as x, 2 as y +|> select x + 1 as x_plus_one, sum(y) as sum_y group by x + 1 +-- !query analysis +Aggregate [(x#x + 1)], [(x#x + 1) AS x_plus_one#x, sum(y#x) AS sum_y#xL] ++- Project [1 AS x#x, 2 AS y#x] + +- OneRowRelation + + +-- !query +table other +|> select a, sum(b) as sum_b group by b +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"a\"", + "expressionAnyValue" : "\"any_value(a)\"" + } +} + + +-- !query +table other +|> extend sum(a) as total_a +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", + "sqlState" : "0A000", + "messageParameters" : { + "clause" : "EXTEND", + "expr" : "sum(a#x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 23, + "stopIndex" : 28, + "fragment" : "sum(a)" + } ] +} + + +-- !query +table other +|> where sum(a) > 5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(a) > 5)\"", + "expressionList" : "sum(spark_catalog.default.other.a)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 31, + "fragment" : "table other\n|> where sum(a) > 5" + } ] +} + + +-- !query +table other +|> aggregate sum(a) as total_a +-- !query analysis +Aggregate [sum(a#x) AS total_a#xL] ++- SubqueryAlias spark_catalog.default.other + +- Relation spark_catalog.default.other[a#x,b#x] json + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 7ad18d73a9c7..15db8d131be1 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -215,25 +215,6 @@ table t table t |> select /*+ repartition(3) */ all x; --- SELECT operators: negative tests. ---------------------------------------- - --- Aggregate functions are not allowed in the pipe operator SELECT list. -table t -|> select sum(x) as result; - -table t -|> select y, length(y) + sum(x) as result; - -from t -|> select sum(x); - -from t as t_alias -|> select y, sum(x); - -from t as t_alias -|> select y, sum(x) group by y; - -- EXTEND operators: positive tests. ------------------------------------ @@ -1157,10 +1138,6 @@ select 1 as x, 2 as y table other |> aggregate a; --- Using aggregate functions without the AGGREGATE keyword is not allowed. -table other -|> select sum(a) as result; - -- The AGGREGATE keyword requires a GROUP BY clause and/or aggregation function(s). table other |> aggregate; @@ -1846,6 +1823,68 @@ set spark.sql.parser.singleCharacterPipeOperator.enabled=true; table t | select x, y; +-- Aggregates in SELECT: positive tests. +------------------------------------------ +-- Aggregate functions can be used in |> SELECT without requiring the |> AGGREGATE keyword. + +-- Aggregates in SELECT. +table other +|> select sum(a) as result; + +-- Aggregates in SELECT with multiple aggregate functions. +table other +|> select sum(a) as total_a, avg(b) as avg_b; + +-- Aggregates in SELECT with WHERE clause. +table other +|> where b > 1 +|> select sum(a) as result; + +-- Aggregates in SELECT with chaining. +table other +|> select sum(a) as total_a +|> select total_a * 2 as doubled; + +-- Mixed aggregates and non-aggregates in SELECT (should work like regular aggregation). +table other +|> select a, sum(b) as sum_b group by a; + +-- Multiple grouping columns in SELECT. +select 1 as x, 2 as y, 3 as z +|> select x, y, sum(z) as total group by x, y; + +-- GROUP BY with ordinal position referring to input column. +table other +|> select a, sum(b) as sum_b group by 1; + +-- Chaining: GROUP BY followed by WHERE on aggregated result. +table other +|> select a, sum(b) as sum_b group by a +|> where sum_b > 1; + +-- GROUP BY with expression and alias. +select 1 as x, 2 as y +|> select x + 1 as x_plus_one, sum(y) as sum_y group by x + 1; + +-- Non-aggregated column without being in GROUP BY should fail. +table other +|> select a, sum(b) as sum_b group by b; + +-- Aggregates in SELECT: negative tests. +------------------------------------------ + +-- Aggregates in EXTEND are not allowed. +table other +|> extend sum(a) as total_a; + +-- Aggregates in WHERE are not allowed. +table other +|> where sum(a) > 5; + +-- The |> AGGREGATE keyword also works for aggregation. +table other +|> aggregate sum(a) as total_a; + -- Cleanup. ----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 3943623aa998..502dc7f057d1 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -579,119 +579,6 @@ struct 1 --- !query -table t -|> select sum(x) as result --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 19, - "stopIndex" : 24, - "fragment" : "sum(x)" - } ] -} - - --- !query -table t -|> select y, length(y) + sum(x) as result --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 34, - "stopIndex" : 39, - "fragment" : "sum(x)" - } ] -} - - --- !query -from t -|> select sum(x) --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 18, - "stopIndex" : 23, - "fragment" : "sum(x)" - } ] -} - - --- !query -from t as t_alias -|> select y, sum(x) --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(x#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 32, - "stopIndex" : 37, - "fragment" : "sum(x)" - } ] -} - - --- !query -from t as t_alias -|> select y, sum(x) group by y --- !query schema -struct<> --- !query output -org.apache.spark.sql.catalyst.parser.ParseException -{ - "errorClass" : "PARSE_SYNTAX_ERROR", - "sqlState" : "42601", - "messageParameters" : { - "error" : "'group'", - "hint" : "" - } -} - - -- !query table t |> extend 1 as z @@ -3325,30 +3212,6 @@ org.apache.spark.sql.AnalysisException } --- !query -table other -|> select sum(a) as result --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", - "sqlState" : "0A000", - "messageParameters" : { - "clause" : "SELECT", - "expr" : "sum(a#x)" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 23, - "stopIndex" : 28, - "fragment" : "sum(a)" - } ] -} - - -- !query table other |> aggregate @@ -4465,6 +4328,167 @@ struct 1 def +-- !query +table other +|> select sum(a) as result +-- !query schema +struct +-- !query output +4 + + +-- !query +table other +|> select sum(a) as total_a, avg(b) as avg_b +-- !query schema +struct +-- !query output +4 2.3333333333333335 + + +-- !query +table other +|> where b > 1 +|> select sum(a) as result +-- !query schema +struct +-- !query output +3 + + +-- !query +table other +|> select sum(a) as total_a +|> select total_a * 2 as doubled +-- !query schema +struct +-- !query output +8 + + +-- !query +table other +|> select a, sum(b) as sum_b group by a +-- !query schema +struct +-- !query output +1 3 +2 4 + + +-- !query +select 1 as x, 2 as y, 3 as z +|> select x, y, sum(z) as total group by x, y +-- !query schema +struct +-- !query output +1 2 3 + + +-- !query +table other +|> select a, sum(b) as sum_b group by 1 +-- !query schema +struct +-- !query output +1 3 +2 4 + + +-- !query +table other +|> select a, sum(b) as sum_b group by a +|> where sum_b > 1 +-- !query schema +struct +-- !query output +1 3 +2 4 + + +-- !query +select 1 as x, 2 as y +|> select x + 1 as x_plus_one, sum(y) as sum_y group by x + 1 +-- !query schema +struct +-- !query output +2 2 + + +-- !query +table other +|> select a, sum(b) as sum_b group by b +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"a\"", + "expressionAnyValue" : "\"any_value(a)\"" + } +} + + +-- !query +table other +|> extend sum(a) as total_a +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "PIPE_OPERATOR_CONTAINS_AGGREGATE_FUNCTION", + "sqlState" : "0A000", + "messageParameters" : { + "clause" : "EXTEND", + "expr" : "sum(a#x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 23, + "stopIndex" : 28, + "fragment" : "sum(a)" + } ] +} + + +-- !query +table other +|> where sum(a) > 5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(a) > 5)\"", + "expressionList" : "sum(spark_catalog.default.other.a)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 31, + "fragment" : "table other\n|> where sum(a) > 5" + } ] +} + + +-- !query +table other +|> aggregate sum(a) as total_a +-- !query schema +struct +-- !query output +4 + + -- !query drop table t -- !query schema