From 50a80611398e6b0425c84a69bd49ac847d2e5acd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 26 Nov 2024 19:40:25 +0100 Subject: [PATCH 1/5] fix grammar to not support empty bodies in loops/if/case --- .../spark/sql/catalyst/parser/SqlBaseParser.g4 | 6 +++--- .../spark/sql/catalyst/parser/AstBuilder.scala | 12 +++++------- .../parser/SqlScriptingParserSuite.scala | 16 +++++++++++++++- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 93cf9974e654c..3b499bc7d9ad2 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -48,15 +48,15 @@ compoundOrSingleStatement ; singleCompoundStatement - : BEGIN compoundBody END SEMICOLON? EOF + : BEGIN compoundBody? END SEMICOLON? EOF ; beginEndCompoundBlock - : beginLabel? BEGIN compoundBody END endLabel? + : beginLabel? BEGIN compoundBody? END endLabel? ; compoundBody - : (compoundStatements+=compoundStatement SEMICOLON)* + : (compoundStatements+=compoundStatement SEMICOLON)+ ; compoundStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a3fac7296dcc4..58a87833bbe2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -144,7 +144,8 @@ class AstBuilder extends DataTypeAstBuilder override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = { val labelCtx = new SqlScriptingLabelContext() - visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = true, labelCtx) + Option(ctx.compoundBody()).map(visitCompoundBodyImpl(_, None, allowVarDeclare = true, labelCtx)) + .getOrElse(CompoundBody(Seq.empty, None)) } private def visitCompoundBodyImpl( @@ -191,12 +192,9 @@ class AstBuilder extends DataTypeAstBuilder labelCtx: SqlScriptingLabelContext): CompoundBody = { val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) - val body = visitCompoundBodyImpl( - ctx.compoundBody(), - Some(labelText), - allowVarDeclare = true, - labelCtx - ) + val body = Option(ctx.compoundBody()) + .map(visitCompoundBodyImpl(_, Some(labelText), allowVarDeclare = true, labelCtx)) + .getOrElse(CompoundBody(Seq.empty, Some(labelText))) labelCtx.exitLabeledScope(Option(ctx.beginLabel())) body } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index ab647f83b42a4..134f0b5d59eb3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -82,7 +82,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { } } - test("empty BEGIN END block") { + test("empty singleCompoundStatement") { val sqlScriptText = """ |BEGIN @@ -91,6 +91,20 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.isEmpty) } + test("empty beginEndCompoundBlock") { + val sqlScriptText = + """ + |BEGIN + | BEGIN + | END; + |END""".stripMargin + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[CompoundBody]) + val innerBody = tree.collection.head.asInstanceOf[CompoundBody] + assert(innerBody.collection.isEmpty) + } + test("multiple ; in row - should fail") { val sqlScriptText = """ From 7a41e6b9242e263eb611c2e850920ecb74083c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 27 Nov 2024 15:45:41 +0100 Subject: [PATCH 2/5] add tests for empty loops --- .../parser/SqlScriptingParserSuite.scala | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 134f0b5d59eb3..2109109a24c57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -453,6 +453,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(ifStmt.conditions.head.getText == "1=1") } + test("if with empty body") { + val sqlScriptText = + """BEGIN + | IF 1 = 1 THEN + | END IF; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'IF'", "hint" -> "")) + } + test("if else") { val sqlScriptText = """BEGIN @@ -637,6 +652,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("while with empty body") { + val sqlScriptText = + """BEGIN + | WHILE 1 = 1 DO + | END WHILE lbl; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'WHILE'", "hint" -> "")) + } + test("while with complex condition") { val sqlScriptText = """ @@ -1081,6 +1111,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(repeatStmt.label.contains("lbl")) } + test("repeat with empty body") { + val sqlScriptText = + """BEGIN + | REPEAT UNTIL 1 = 1 + | END REPEAT; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'1'", "hint" -> "")) + } + test("repeat with complex condition") { val sqlScriptText = """ @@ -1211,6 +1256,22 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(caseStmt.conditions.head.getText == "1 = 1") } + test("searched case statement with empty body") { + val sqlScriptText = + """BEGIN + | CASE + | WHEN 1 = 1 THEN + | END CASE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'CASE'", "hint" -> "")) + } + test("searched case statement - multi when") { val sqlScriptText = """ @@ -1349,6 +1410,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1)) } + test("simple case statement with empty body") { + val sqlScriptText = + """BEGIN + | CASE 1 + | WHEN 1 THEN + | END CASE; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'CASE'", "hint" -> "")) + } test("simple case statement - multi when") { val sqlScriptText = @@ -1496,6 +1572,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(whileStmt.label.contains("lbl")) } + test("loop with empty body") { + val sqlScriptText = + """BEGIN + | LOOP + | END LOOP; + |END + """.stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'LOOP'", "hint" -> "")) + } + test("loop with if else block") { val sqlScriptText = """BEGIN From 729e7359313c47f4ec311318a1d3f2bc86971dce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 27 Nov 2024 16:15:20 +0100 Subject: [PATCH 3/5] format --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 58a87833bbe2e..ac5d03e628acd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -144,7 +144,8 @@ class AstBuilder extends DataTypeAstBuilder override def visitSingleCompoundStatement(ctx: SingleCompoundStatementContext): CompoundBody = { val labelCtx = new SqlScriptingLabelContext() - Option(ctx.compoundBody()).map(visitCompoundBodyImpl(_, None, allowVarDeclare = true, labelCtx)) + Option(ctx.compoundBody()) + .map(visitCompoundBodyImpl(_, None, allowVarDeclare = true, labelCtx)) .getOrElse(CompoundBody(Seq.empty, None)) } From 2260faabf31e4916e89247396983a00b096ccbef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 27 Nov 2024 17:17:06 +0100 Subject: [PATCH 4/5] fix parser test --- .../spark/sql/catalyst/parser/SqlScriptingParserSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 2109109a24c57..690577de3868d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -656,7 +656,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val sqlScriptText = """BEGIN | WHILE 1 = 1 DO - | END WHILE lbl; + | END WHILE; |END """.stripMargin checkError( From 1cafcb5d5b6ec122d483287195bf1c982892f437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 4 Dec 2024 22:48:18 +0100 Subject: [PATCH 5/5] add tests for for statement --- .../parser/SqlScriptingParserSuite.scala | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 690577de3868d..c9e2f42e164f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -2065,6 +2065,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } + test("for statement - empty body") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT 5 DO + | END FOR; + |END""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'FOR'", "hint" -> "")) + } + test("for statement - no label") { val sqlScriptText = """ @@ -2181,6 +2196,21 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } + test("for statement - no variable - empty body") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT 5 DO + | END FOR; + |END""".stripMargin + checkError( + exception = intercept[ParseException] { + parsePlan(sqlScriptText) + }, + condition = "PARSE_SYNTAX_ERROR", + parameters = Map("error" -> "'FOR'", "hint" -> "")) + } + test("for statement - no variable - no label") { val sqlScriptText = """