diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f13dde773496a..5034d33f6717c 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1488,6 +1488,7 @@ version operatorPipeRightSide : selectClause + | whereClause ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cb0e0e35c3704..040fcfa99e8f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5854,7 +5854,20 @@ class AstBuilder extends DataTypeAstBuilder windowClause = null, relation = left, isPipeOperatorSelect = true) - }.get + }.getOrElse(Option(ctx.whereClause).map { c => + // Add a table subquery boundary between the new filter and the input plan if one does not + // already exist. This helps the analyzer behave as if we had added the WHERE clause after a + // table subquery containing the input plan. + val withSubqueryAlias = left match { + case s: SubqueryAlias => + s + case u: UnresolvedRelation => + u + case _ => + SubqueryAlias(SubqueryAlias.generateSubqueryName(), left) + } + withWhereClause(c, withSubqueryAlias) + }.get) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index ab0635fef048b..c44ce153a2f41 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -255,6 +255,55 @@ Distinct +- Relation spark_catalog.default.t[x#x,y#x] csv +-- !query +table t +|> select * +-- !query analysis +Project [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select * except (y) +-- !query analysis +Project [x#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query analysis +Repartition 3, true ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query analysis +Repartition 3, true ++- Distinct + +- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query analysis +Repartition 3, true ++- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + -- !query table t |> select sum(x) as result @@ -297,6 +346,229 @@ org.apache.spark.sql.AnalysisException } +-- !query +table t +|> where true +-- !query analysis +Filter true ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +-- !query analysis +Filter ((x#x + length(y#x)) < 4) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query analysis +Filter ((x#x + length(y#x)) < 3) ++- SubqueryAlias __auto_generated_subquery_name + +- Filter ((x#x + length(y#x)) < 4) + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Aggregate [x#x], [x#x, sum(length(y#x)) AS sum_len#xL] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query analysis +Filter (col#x.i1 = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Project [col#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query analysis +Filter (col#x.i1 = 2) ++- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query analysis +Filter exists#x [x#x] +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Project [a#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query analysis +Filter (scalar-subquery#x [x#x] = 1) +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Aggregate [any_value(a#x, false) AS any_value(a)#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 7d0966e7f2095..49a72137ee047 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -12,7 +12,7 @@ drop table if exists st; create table st(x int, col struct) using parquet; insert into st values (1, (2, 3)); --- Selection operators: positive tests. +-- SELECT operators: positive tests. --------------------------------------- -- Selecting a constant. @@ -85,7 +85,24 @@ table t table t |> select distinct x, y; --- Selection operators: negative tests. +-- SELECT * is supported. +table t +|> select *; + +table t +|> select * except (y); + +-- Hints are supported. +table t +|> select /*+ repartition(3) */ *; + +table t +|> select /*+ repartition(3) */ distinct x; + +table t +|> select /*+ repartition(3) */ all x; + +-- SELECT operators: negative tests. --------------------------------------- -- Aggregate functions are not allowed in the pipe operator SELECT list. @@ -95,6 +112,79 @@ table t table t |> select y, length(y) + sum(x) as result; +-- WHERE operators: positive tests. +----------------------------------- + +-- Filtering with a constant predicate. +table t +|> where true; + +-- Filtering with a predicate based on attributes from the input relation. +table t +|> where x + length(y) < 4; + +-- Two consecutive filters are allowed. +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3; + +-- It is possible to use the WHERE operator instead of the HAVING clause when processing the result +-- of aggregations. For example, this WHERE operator is equivalent to the normal SQL "HAVING x = 1". +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1; + +-- Filtering by referring to the table or table subquery alias. +table t +|> where t.x = 1; + +table t +|> where spark_catalog.default.t.x = 1; + +-- Filtering using struct fields. +(select col from st) +|> where col.i1 = 1; + +table st +|> where st.col.i1 = 2; + +-- Expression subqueries in the WHERE clause. +table t +|> where exists (select a from other where x = a limit 1); + +-- Aggregations are allowed within expression subqueries in the pipe operator WHERE clause as long +-- no aggregate functions exist in the top-level expression predicate. +table t +|> where (select any_value(a) from other where x = a limit 1) = 1; + +-- WHERE operators: negative tests. +----------------------------------- + +-- Aggregate functions are not allowed in the top-level WHERE predicate. +-- (Note: to implement this behavior, perform the aggregation first separately and then add a +-- pipe-operator WHERE clause referring to the result of aggregate expression(s) therein). +table t +|> where sum(x) = 1; + +table t +|> where y = 'abc' or length(y) + sum(x) = 1; + +-- Window functions are not allowed in the WHERE clause (pipe operators or otherwise). +table t +|> where first_value(x) over (partition by y) = 1; + +select * from t where first_value(x) over (partition by y) = 1; + +-- Pipe operators may only refer to attributes produced as output from the directly-preceding +-- pipe operator, not from earlier ones. +table t +|> select x, length(y) as z +|> where x + length(y) < 4; + +-- If the WHERE clause wants to filter rows produced by an aggregation, it is not valid to try to +-- refer to the aggregate functions directly; it is necessary to use aliases instead. +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3; + -- Cleanup. ----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 7e0b7912105c2..38436b0941034 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -238,6 +238,56 @@ struct 1 def +-- !query +table t +|> select * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select * except (y) +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query schema +struct +-- !query output +0 +1 + + -- !query table t |> select sum(x) as result @@ -284,6 +334,224 @@ org.apache.spark.sql.AnalysisException } +-- !query +table t +|> where true +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> where x + length(y) < 4 +-- !query schema +struct +-- !query output +0 abc + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query schema +struct +-- !query output + + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query schema +struct +-- !query output +1 3 + + +-- !query +table t +|> where t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query schema +struct> +-- !query output + + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query schema +struct> +-- !query output +1 {"i1":2,"i2":3} + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a80444feb68ae..ab949c5a21e44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -895,6 +895,16 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { checkPipeSelect("TABLE t |> SELECT 1 AS X") checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") + // Basic WHERE operators. + def checkPipeWhere(query: String): Unit = { + val plan: LogicalPlan = parser.parsePlan(query) + assert(plan.containsPattern(FILTER)) + assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) + } + checkPipeWhere("TABLE t |> WHERE X = 1") + checkPipeWhere("TABLE t |> SELECT X, LENGTH(Y) AS Z |> WHERE X + LENGTH(Y) < 4") + checkPipeWhere("TABLE t |> WHERE X = 1 AND Y = 2 |> WHERE X + Y = 3") + checkPipeWhere("VALUES (0), (1) tab(col) |> WHERE col < 1") } } }