From 42663a63e77bda5b763acec82c1eda865b2df2fd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?=
Date: Fri, 1 Nov 2024 16:44:26 +0100
Subject: [PATCH 01/39] for loop initial version

---
 .../sql/catalyst/parser/SqlBaseParser.g4      |  5 +
 .../sql/catalyst/parser/AstBuilder.scala      | 14 +++
 .../logical/SqlScriptingLogicalPlans.scala    | 15 +++
 .../parser/SqlScriptingParserSuite.scala      | 22 +++++
 .../scripting/SqlScriptingExecutionNode.scala | 99 ++++++++++++++++++-
 .../scripting/SqlScriptingInterpreter.scala   |  8 +-
 .../SqlScriptingInterpreterSuite.scala        | 24 +++++
 7 files changed, 185 insertions(+), 2 deletions(-)

diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index cdee8c906054d..29e481493ff7c 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -70,6 +70,7 @@ compoundStatement
     | leaveStatement
     | iterateStatement
     | loopStatement
+    | forStatement
     ;

 setStatementWithOptionalVarKeyword
@@ -111,6 +112,10 @@ loopStatement
     : beginLabel? LOOP compoundBody END LOOP endLabel?
     ;

+forStatement
+    : beginLabel? FOR (multipartIdentifier AS)? query DO compoundBody END FOR endLabel?
+    ;
+
 singleStatement
     : (statement|setResetStatement) SEMICOLON* EOF
     ;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 08a8cf6bab87a..528fdb12a2628 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -347,6 +347,16 @@ class AstBuilder extends DataTypeAstBuilder
     RepeatStatement(condition, body, Some(labelText))
   }

+  override def visitForStatement(ctx: ForStatementContext): ForStatement = {
+    val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel()))
+
+    val query = SingleStatement(visitQuery(ctx.query()))
+    val identifier = Option(ctx.multipartIdentifier()).map(_.getText)
+    val body = visitCompoundBody(ctx.compoundBody())
+
+    ForStatement(query, identifier, body, Some(labelText))
+  }
+
   private def leaveOrIterateContextHasLabel(
       ctx: RuleContext, label: String, isIterate: Boolean): Boolean = {
     ctx match {
@@ -369,6 +379,10 @@ class AstBuilder extends DataTypeAstBuilder
           if Option(c.beginLabel()).isDefined &&
             c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
         => true
+      case c: ForStatementContext
+        if Option(c.beginLabel()).isDefined &&
+          c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+        => true
       case _ => false
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index e6018e5e57b9c..fd511a119b078 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -267,3 +267,18 @@ case class LoopStatement(
     LoopStatement(newChildren(0).asInstanceOf[CompoundBody], label)
   }
 }
+
+/**
+ * Logical operator for FOR statement.
+ * @param query Query whose result set the loop iterates over; the body runs once per row.
+ * @param identifier An optional name for the variable that holds the current row.
+ * @param body Compound body is a collection of statements that are executed for each row.
+ * @param label An optional label for the loop which is unique amongst all labels for statements
+ *              within which the FOR statement is contained.
+ *              If an end label is specified it must match the beginning label.
+ *              The label can be used to LEAVE or ITERATE the loop.
+ */
+case class ForStatement(
+    query: SingleStatement,
+    identifier: Option[String],
+    body: CompoundBody,
+    label: Option[String]) extends CompoundPlanStatement
\ No newline at end of file
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 3bb84f603dc67..7addedab9dffb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -40,6 +40,28 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
   }

   // Tests
+  test("testtest") {
+    val sqlScriptText =
+      """SELECT named_struct('a', 1, 'b', 2, 'c', 3); """.stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[ForStatement])
+  }
+
+  test("initial for") {
+    val sqlScriptText =
+      """
+        |BEGIN
+        |  FOR x AS (SELECT 1) DO
+        |    SELECT 1;
+        |    SELECT 2;
+        |  END FOR;
+        |END""".stripMargin
+    val tree = parseScript(sqlScriptText)
+    assert(tree.collection.length == 1)
+    assert(tree.collection.head.isInstanceOf[ForStatement])
+  }
+
   test("single select") {
     val sqlScriptText = "SELECT 1;"
     val statement = parsePlan(sqlScriptText)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 9129fc6ab00f3..7009529fdacc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -19,9 +19,13 @@ package org.apache.spark.sql.scripting

 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, Literal}
+import org.apache.spark.sql.catalyst.parser.SingleStatement
 import org.apache.spark.sql.{Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, SetVariable}
 import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.SqlScriptingErrors
 import org.apache.spark.sql.types.BooleanType

@@ -636,3 +640,96 @@ class LoopStatementExec(
     body.reset()
   }
 }
+
+/**
+ * Executable node for ForStatement.
+ * @param query Executable node for the query whose result set the loop iterates over.
+ * @param identifier Name of the loop variable holding the current row, if one was specified.
+ * @param body Executable node for the body.
+ * @param label Label set to ForStatement by user or None otherwise.
+ * @param session Spark session that SQL script is executed within.
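+ *
+ * Conceptually, each iteration first emits an internal statement equivalent to the
+ * following sketch (a single result column `intCol` and loop variable `x` are assumed
+ * here purely for illustration):
+ * {{{
+ *   SET x = named_struct('intCol', <current row's intCol>);
+ * }}}
+ * and then forwards the statements of the loop body for that row.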
+ */ +class ForStatementExec( + query: SingleStatementExec, + identifier: Option[String], + body: CompoundBodyExec, + label: Option[String], + session: SparkSession) extends NonLeafStatementExec { + + private val queryDataframe = { + query.isExecuted = true + Dataset.ofRows(session, query.parsedPlan) + } + private lazy val queryResult = queryDataframe.collect() + + private object ForState extends Enumeration { + val VariableAssignment, Body = Value + } + private var state = ForState.VariableAssignment + private var currRow = 0 + + /** + * Loop can be interrupted by LeaveStatementExec + */ + private var interrupted: Boolean = false + + private lazy val treeIterator: Iterator[CompoundStatementExec] = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = !interrupted && currRow < queryResult.length + + override def next(): CompoundStatementExec = state match { + case ForState.VariableAssignment => + + val namedStructArgs: Array[Expression] = queryDataframe.schema.names.flatMap { colName => + List(UnresolvedAttribute(colName), Literal(queryResult(0).getAs(colName))) + } + val namedStruct = Project( + Seq(Alias(CreateNamedStruct(namedStructArgs), identifier.get)()), + OneRowRelation()) + + val setIdentifierToCurrentRow = + SetVariable(Seq(UnresolvedAttribute(identifier.get)), namedStruct) + + val setExec = new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) + + state = ForState.Body + currRow += 1 + body.reset() + + setExec + case ForState.Body => + val retStmt = body.getTreeIterator.next() + + // Handle LEAVE or ITERATE statement if it has been encountered. + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + interrupted = true + return retStmt + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } + state = ForState.VariableAssignment + return retStmt + case _ => + } + + if (!body.getTreeIterator.hasNext) { + state = ForState.VariableAssignment + } + retStmt + } + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + state = ForState.VariableAssignment + body.reset() + } +} + +// val attributes = DataTypeUtils.toAttributes(queryResult.head.schema) +// LocalRelation.fromExternalRows(attributes, Seq(queryResult(0))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 1be75cb61c8b0..20efb25158b05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, IfElseStatement, IterateStatement, LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, CreateVariable, DropVariable, ForStatement, IfElseStatement, IterateStatement, 
LeaveStatement, LogicalPlan, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.catalyst.trees.Origin /** @@ -123,6 +123,12 @@ case class SqlScriptingInterpreter() { val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] new LoopStatementExec(bodyExec, label) + case ForStatement(query, identifier, body, label) => + val queryExec = new SingleStatementExec(query.parsedPlan, query.origin, isInternal = false) + val bodyExec = + transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] + new ForStatementExec(queryExec, identifier, bodyExec, label, session) + case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index b0b844d2b52ca..e1910aa8bce2d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -60,6 +60,30 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } } + + test("for test") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; + | INSERT INTO t VALUES (1, 'a', 1.0); + | INSERT INTO t VALUES (2, 'b', 2.0); + | FOR x AS SELECT * FROM t DO + | SELECT x; + | END FOR; + |END + |""".stripMargin + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // select with filter + Seq(Row(1)) // select + ) + verifySqlScriptResult(sqlScript, expected) + } + } + // Tests test("multi statement - simple") { withTable("t") { From e7492cfbc35ae5baecc6a42a08b6df007e83e9b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 4 Nov 2024 15:56:17 +0100 Subject: [PATCH 02/39] first time working --- .../parser/SqlScriptingParserSuite.scala | 2 +- .../scripting/SqlScriptingExecutionNode.scala | 45 ++++++++++++------- .../SqlScriptingInterpreterSuite.scala | 25 ++++++++--- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 7addedab9dffb..e0bd7023fd112 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -52,7 +52,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val sqlScriptText = """ |BEGIN - | FOR x AS (SELECT 1) DO + | FOR x AS SELECT 1 DO | SELECT 1; | SELECT 2; | END FOR; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 7009529fdacc1..882cca35450d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import 
org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, Literal} -import org.apache.spark.sql.catalyst.parser.SingleStatement import org.apache.spark.sql.{Dataset, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, SetVariable} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} -import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -654,8 +652,9 @@ class ForStatementExec( body: CompoundBodyExec, label: Option[String], session: SparkSession) extends NonLeafStatementExec { + // fali reset, drop variable, izvlaciti metode - private val queryDataframe = { + private lazy val queryDataframe = { query.isExecuted = true Dataset.ofRows(session, query.parsedPlan) } @@ -667,6 +666,8 @@ class ForStatementExec( private var state = ForState.VariableAssignment private var currRow = 0 + private var isVariableDeclared = false + /** * Loop can be interrupted by LeaveStatementExec */ @@ -678,21 +679,33 @@ class ForStatementExec( override def next(): CompoundStatementExec = state match { case ForState.VariableAssignment => - - val namedStructArgs: Array[Expression] = queryDataframe.schema.names.flatMap { colName => - List(UnresolvedAttribute(colName), Literal(queryResult(0).getAs(colName))) + val namedStructArgs: Seq[Expression] = + queryDataframe.schema.names.toSeq.flatMap { colName => + Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) + } + val namedStruct = CreateNamedStruct(namedStructArgs) + + if (!isVariableDeclared) { + isVariableDeclared = true + val defaultExpression = + DefaultValueExpression(Literal(null, namedStruct.dataType), "null") + val declareVariable = CreateVariable( + UnresolvedIdentifier(Seq(identifier.get)), + defaultExpression, + replace = true + ) + return new SingleStatementExec(declareVariable, Origin(), false) } - val namedStruct = Project( - Seq(Alias(CreateNamedStruct(namedStructArgs), identifier.get)()), - OneRowRelation()) + val projectNamedStruct = Project( + Seq(Alias(namedStruct, identifier.get)()), + OneRowRelation() + ) val setIdentifierToCurrentRow = - SetVariable(Seq(UnresolvedAttribute(identifier.get)), namedStruct) - + SetVariable(Seq(UnresolvedAttribute(identifier.get)), projectNamedStruct) val setExec = new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) state = ForState.Body - currRow += 1 body.reset() setExec @@ -711,12 +724,14 @@ class ForStatementExec( if (label.contains(iterStatementExec.label)) { iterStatementExec.hasBeenMatched = true } + currRow += 1 state = ForState.VariableAssignment return retStmt case _ => } if (!body.getTreeIterator.hasNext) { + currRow += 1 state = ForState.VariableAssignment } retStmt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index e1910aa8bce2d..e7a37eb9faafe 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -66,19 +66,30 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN - | CREATE TABLE t (a INT, b STRING, c DOUBLE) USING parquet; - | INSERT INTO t VALUES (1, 'a', 1.0); - | INSERT INTO t VALUES (2, 'b', 2.0); - | FOR x AS SELECT * FROM t DO - | SELECT x; + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | SELECT x.stringCol; + | SELECT x.doubleCol; | END FOR; |END |""".stripMargin + val expected = Seq( Seq.empty[Row], // create table Seq.empty[Row], // insert - Seq.empty[Row], // select with filter - Seq(Row(1)) // select + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // set x to row 1 + Seq(Row(2)), // select intCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)) // select doubleCol ) verifySqlScriptResult(sqlScript, expected) } From 63289c78a697b4265fb2663e0623041b12e33556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 4 Nov 2024 17:50:00 +0100 Subject: [PATCH 03/39] drop local variable at end of execution --- .../scripting/SqlScriptingExecutionNode.scala | 22 ++++++++++++++++--- .../SqlScriptingInterpreterSuite.scala | 3 ++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 882cca35450d5..b4cff5941450e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -652,7 +652,7 @@ class ForStatementExec( body: CompoundBodyExec, label: Option[String], session: SparkSession) extends NonLeafStatementExec { - // fali reset, drop variable, izvlaciti metode + // fali reset, drop variable, extract methods, case when identifier is None private lazy val queryDataframe = { query.isExecuted = true @@ -675,10 +675,24 @@ class ForStatementExec( private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { - override def hasNext: Boolean = !interrupted && currRow < queryResult.length + override def hasNext: Boolean = + !interrupted && + 
queryResult.length > 0 && + // not currRow < queryResult.length because when + // currRow == queryResult.length we drop the local variable + currRow <= queryResult.length override def next(): CompoundStatementExec = state match { case ForState.VariableAssignment => + // after all rows in the result set have been iterated, the local variable is dropped + if (currRow == queryResult.length) { + // set currRow to queryResult.length + 1 to end execution after current .next call + currRow += 1 + val dropVariable = + DropVariable(UnresolvedIdentifier(Seq(identifier.get)), ifExists = true) + return new SingleStatementExec(dropVariable, Origin(), false) + } + val namedStructArgs: Seq[Expression] = queryDataframe.schema.names.toSeq.flatMap { colName => Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) @@ -709,6 +723,7 @@ class ForStatementExec( body.reset() setExec + case ForState.Body => val retStmt = body.getTreeIterator.next() @@ -742,6 +757,7 @@ class ForStatementExec( override def reset(): Unit = { state = ForState.VariableAssignment + currRow = 0 body.reset() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index e7a37eb9faafe..d3ffc99fd1c9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -89,7 +89,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // set x to row 1 Seq(Row(2)), // select intCol Seq(Row("second")), // select stringCol - Seq(Row(2.0)) // select doubleCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row] // drop x ) verifySqlScriptResult(sqlScript, expected) } From 7de3e7605afa5599f211ff4ab1b65640a6342dad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 4 Nov 2024 19:25:59 +0100 Subject: [PATCH 04/39] refactor to drop var after every iteration --- .../scripting/SqlScriptingExecutionNode.scala | 57 +++++++++---------- .../scripting/SqlScriptingInterpreter.scala | 13 ++++- .../SqlScriptingInterpreterSuite.scala | 2 + 3 files changed, 42 insertions(+), 30 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index b4cff5941450e..04f5b38747f3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -652,7 +652,7 @@ class ForStatementExec( body: 
CompoundBodyExec, label: Option[String], session: SparkSession) extends NonLeafStatementExec { - // fali reset, drop variable, extract methods, case when identifier is None + // fali reset, extract methods, case when identifier is None private lazy val queryDataframe = { query.isExecuted = true @@ -661,9 +661,9 @@ class ForStatementExec( private lazy val queryResult = queryDataframe.collect() private object ForState extends Enumeration { - val VariableAssignment, Body = Value + val VariableDeclaration, VariableAssignment, Body = Value } - private var state = ForState.VariableAssignment + private var state = ForState.VariableDeclaration private var currRow = 0 private var isVariableDeclared = false @@ -680,36 +680,34 @@ class ForStatementExec( queryResult.length > 0 && // not currRow < queryResult.length because when // currRow == queryResult.length we drop the local variable - currRow <= queryResult.length + currRow < queryResult.length override def next(): CompoundStatementExec = state match { - case ForState.VariableAssignment => - // after all rows in the result set have been iterated, the local variable is dropped - if (currRow == queryResult.length) { - // set currRow to queryResult.length + 1 to end execution after current .next call - currRow += 1 - val dropVariable = - DropVariable(UnresolvedIdentifier(Seq(identifier.get)), ifExists = true) - return new SingleStatementExec(dropVariable, Origin(), false) - } - + case ForState.VariableDeclaration => val namedStructArgs: Seq[Expression] = queryDataframe.schema.names.toSeq.flatMap { colName => Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) } val namedStruct = CreateNamedStruct(namedStructArgs) - if (!isVariableDeclared) { - isVariableDeclared = true - val defaultExpression = - DefaultValueExpression(Literal(null, namedStruct.dataType), "null") - val declareVariable = CreateVariable( - UnresolvedIdentifier(Seq(identifier.get)), - defaultExpression, - replace = true - ) - return new SingleStatementExec(declareVariable, Origin(), false) - } + val defaultExpression = + DefaultValueExpression(Literal(null, namedStruct.dataType), "null") + val declareVariable = CreateVariable( + UnresolvedIdentifier(Seq(identifier.get)), + defaultExpression, + replace = true + ) + val declareExec = new SingleStatementExec(declareVariable, Origin(), false) + + state = ForState.VariableAssignment + + declareExec + case ForState.VariableAssignment => + val namedStructArgs: Seq[Expression] = + queryDataframe.schema.names.toSeq.flatMap { colName => + Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) + } + val namedStruct = CreateNamedStruct(namedStructArgs) val projectNamedStruct = Project( Seq(Alias(namedStruct, identifier.get)()), @@ -740,14 +738,14 @@ class ForStatementExec( iterStatementExec.hasBeenMatched = true } currRow += 1 - state = ForState.VariableAssignment + state = ForState.VariableDeclaration return retStmt case _ => } if (!body.getTreeIterator.hasNext) { currRow += 1 - state = ForState.VariableAssignment + state = ForState.VariableDeclaration } retStmt } @@ -756,8 +754,9 @@ class ForStatementExec( override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator override def reset(): Unit = { - state = ForState.VariableAssignment + state = ForState.VariableDeclaration currRow = 0 + isVariableDeclared = false body.reset() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 20efb25158b05..b1e3532217f96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -127,7 +127,18 @@ case class SqlScriptingInterpreter() { val queryExec = new SingleStatementExec(query.parsedPlan, query.origin, isInternal = false) val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] - new ForStatementExec(queryExec, identifier, bodyExec, label, session) + val dropVariableExec = new SingleStatementExec( + DropVariable(UnresolvedIdentifier(Seq(identifier.get)), ifExists = true), + Origin(), + isInternal = true) + + new ForStatementExec( + queryExec, + identifier, + new CompoundBodyExec(Seq(bodyExec, dropVariableExec)), + label, + session + ) case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index d3ffc99fd1c9c..83a71e258fa52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -86,6 +86,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1)), // select intCol Seq(Row("first")), // select stringCol Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x Seq.empty[Row], // set x to row 1 Seq(Row(2)), // select intCol Seq(Row("second")), // select stringCol From 91fb9ee4fc01b10ae9aa248b87a70a23046bcfb0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 5 Nov 2024 10:51:51 +0100 Subject: [PATCH 05/39] cleanup code --- .../scripting/SqlScriptingExecutionNode.scala | 63 +++++++++---------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 04f5b38747f3a..6df0dd96ec675 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -665,8 +665,7 @@ class ForStatementExec( } private var state = ForState.VariableDeclaration private var currRow = 0 - - private var isVariableDeclared = false + private var currNamedStruct: CreateNamedStruct = null /** * Loop can be interrupted by LeaveStatementExec @@ -676,11 +675,7 @@ class ForStatementExec( private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = - !interrupted && - queryResult.length > 0 && - // not currRow < queryResult.length because when - // currRow == queryResult.length we drop the local variable - currRow < queryResult.length + !interrupted && queryResult.length > 0 && currRow < queryResult.length override def next(): CompoundStatementExec = state match { case ForState.VariableDeclaration => @@ -688,39 +683,15 @@ class ForStatementExec( queryDataframe.schema.names.toSeq.flatMap { colName => Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) } - val namedStruct = 
CreateNamedStruct(namedStructArgs) - - val defaultExpression = - DefaultValueExpression(Literal(null, namedStruct.dataType), "null") - val declareVariable = CreateVariable( - UnresolvedIdentifier(Seq(identifier.get)), - defaultExpression, - replace = true - ) - val declareExec = new SingleStatementExec(declareVariable, Origin(), false) + currNamedStruct = CreateNamedStruct(namedStructArgs) state = ForState.VariableAssignment + createDeclareVarExec(currNamedStruct) - declareExec case ForState.VariableAssignment => - val namedStructArgs: Seq[Expression] = - queryDataframe.schema.names.toSeq.flatMap { colName => - Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) - } - val namedStruct = CreateNamedStruct(namedStructArgs) - - val projectNamedStruct = Project( - Seq(Alias(namedStruct, identifier.get)()), - OneRowRelation() - ) - val setIdentifierToCurrentRow = - SetVariable(Seq(UnresolvedAttribute(identifier.get)), projectNamedStruct) - val setExec = new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) - state = ForState.Body body.reset() - - setExec + createSetVarExec(currNamedStruct) case ForState.Body => val retStmt = body.getTreeIterator.next() @@ -751,12 +722,34 @@ class ForStatementExec( } } + private def createDeclareVarExec(namedStruct: CreateNamedStruct): SingleStatementExec = { + val defaultExpression = DefaultValueExpression(Literal(null, namedStruct.dataType), "null") + val declareVariable = CreateVariable( + UnresolvedIdentifier(Seq(identifier.get)), + defaultExpression, + replace = true + ) + val declareExec = new SingleStatementExec(declareVariable, Origin(), false) + declareExec + } + + private def createSetVarExec(namedStruct: CreateNamedStruct): SingleStatementExec = { + val projectNamedStruct = Project( + Seq(Alias(namedStruct, identifier.get)()), + OneRowRelation() + ) + val setIdentifierToCurrentRow = + SetVariable(Seq(UnresolvedAttribute(identifier.get)), projectNamedStruct) + val setExec = new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) + setExec + } + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator override def reset(): Unit = { + // TODO: run query again state = ForState.VariableDeclaration currRow = 0 - isVariableDeclared = false body.reset() } } From 5154a4897db4a69fdb88b7cfa9c87aae38290f7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 5 Nov 2024 12:20:46 +0100 Subject: [PATCH 06/39] support for FOR without variable --- .../scripting/SqlScriptingExecutionNode.scala | 70 ++++++++++++------- .../scripting/SqlScriptingInterpreter.scala | 25 ++++--- .../SqlScriptingInterpreterSuite.scala | 25 +++++++ 3 files changed, 82 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 6df0dd96ec675..f25dd4be81c7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, 
CreateNamedStruct, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} @@ -652,20 +652,34 @@ class ForStatementExec( body: CompoundBodyExec, label: Option[String], session: SparkSession) extends NonLeafStatementExec { - // fali reset, extract methods, case when identifier is None - - private lazy val queryDataframe = { - query.isExecuted = true - Dataset.ofRows(session, query.parsedPlan) - } - private lazy val queryResult = queryDataframe.collect() - + // fali reset, case when identifier is None private object ForState extends Enumeration { val VariableDeclaration, VariableAssignment, Body = Value } private var state = ForState.VariableDeclaration private var currRow = 0 - private var currNamedStruct: CreateNamedStruct = null + private var currVariable: CreateNamedStruct = null + + private var queryDataframe: DataFrame = null + private var isDataframeCacheValid = false + private def cachedQueryDataframe(): DataFrame = { + if (!isDataframeCacheValid) { + query.isExecuted = true + queryDataframe = Dataset.ofRows(session, query.parsedPlan) + isDataframeCacheValid = true + } + queryDataframe + } + + private var queryResult: Array[Row] = null + private var isResultCacheValid = false + private def cachedQueryResult(): Array[Row] = { + if (!isResultCacheValid) { + queryResult = cachedQueryDataframe().collect() + isResultCacheValid = true + } + queryResult + } /** * Loop can be interrupted by LeaveStatementExec @@ -675,23 +689,29 @@ class ForStatementExec( private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { override def hasNext: Boolean = - !interrupted && queryResult.length > 0 && currRow < queryResult.length + !interrupted && cachedQueryResult().length > 0 && currRow < cachedQueryResult().length override def next(): CompoundStatementExec = state match { case ForState.VariableDeclaration => + if (identifier.isEmpty) { + state = ForState.Body + body.reset() + return next() + } + val namedStructArgs: Seq[Expression] = - queryDataframe.schema.names.toSeq.flatMap { colName => - Seq(Literal(colName), Literal(queryResult(currRow).getAs(colName))) + cachedQueryDataframe().schema.names.toSeq.flatMap { colName => + Seq(Literal(colName), Literal(cachedQueryResult()(currRow).getAs(colName))) } - currNamedStruct = CreateNamedStruct(namedStructArgs) + currVariable = CreateNamedStruct(namedStructArgs) state = ForState.VariableAssignment - createDeclareVarExec(currNamedStruct) + createDeclareVarExec(currVariable) case ForState.VariableAssignment => state = ForState.Body body.reset() - createSetVarExec(currNamedStruct) + createSetVarExec(currVariable) case ForState.Body => val retStmt = body.getTreeIterator.next() @@ -722,34 +742,34 @@ class ForStatementExec( } } - private def createDeclareVarExec(namedStruct: CreateNamedStruct): SingleStatementExec = { - val defaultExpression = DefaultValueExpression(Literal(null, namedStruct.dataType), "null") + private def createDeclareVarExec(variable: Expression): SingleStatementExec = { + val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") val declareVariable = CreateVariable( UnresolvedIdentifier(Seq(identifier.get)), defaultExpression, replace = true ) - val declareExec = new SingleStatementExec(declareVariable, Origin(), false) - declareExec + new SingleStatementExec(declareVariable, Origin(), false) } - private def createSetVarExec(namedStruct: CreateNamedStruct): 
SingleStatementExec = { + private def createSetVarExec(variable: Expression): SingleStatementExec = { val projectNamedStruct = Project( - Seq(Alias(namedStruct, identifier.get)()), + Seq(Alias(variable, identifier.get)()), OneRowRelation() ) val setIdentifierToCurrentRow = SetVariable(Seq(UnresolvedAttribute(identifier.get)), projectNamedStruct) - val setExec = new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) - setExec + new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) } override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator override def reset(): Unit = { - // TODO: run query again state = ForState.VariableDeclaration + isDataframeCacheValid = false + isResultCacheValid = false currRow = 0 + currVariable = null body.reset() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index b1e3532217f96..f67255ee210bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -123,22 +123,21 @@ case class SqlScriptingInterpreter() { val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] new LoopStatementExec(bodyExec, label) - case ForStatement(query, identifier, body, label) => + case ForStatement(query, identifierOpt, body, label) => val queryExec = new SingleStatementExec(query.parsedPlan, query.origin, isInternal = false) val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] - val dropVariableExec = new SingleStatementExec( - DropVariable(UnresolvedIdentifier(Seq(identifier.get)), ifExists = true), - Origin(), - isInternal = true) - - new ForStatementExec( - queryExec, - identifier, - new CompoundBodyExec(Seq(bodyExec, dropVariableExec)), - label, - session - ) + val finalExec = identifierOpt match { + case None => bodyExec + case Some(identifier) => + val dropVariableExec = new SingleStatementExec( + DropVariable(UnresolvedIdentifier(Seq(identifier)), ifExists = true), + Origin(), + isInternal = true) + new CompoundBodyExec(Seq(bodyExec, dropVariableExec)) + } + + new ForStatementExec(queryExec, identifierOpt, finalExec, label, session) case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 83a71e258fa52..b97dfb6741519 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -98,6 +98,31 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for test no variable") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT 1; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select 1 + Seq(Row(1)), // select 1 + ) + verifySqlScriptResult(sqlScript, expected) + } 
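+    // Note on the expected sequence above: with no loop variable in the FOR clause,
+    // ForStatementExec skips variable declaration/assignment entirely, so only the
+    // body's SELECT results appear (no declare/set/drop rows).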
+ } + // Tests test("multi statement - simple") { withTable("t") { From 4ca1ed2c01692d6ac7289818e7c6f8a5b69f4cf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 5 Nov 2024 12:43:43 +0100 Subject: [PATCH 07/39] adding tests --- .../scripting/SqlScriptingExecutionNode.scala | 4 +- .../SqlScriptingInterpreterSuite.scala | 88 +++++++++++++++++++ 2 files changed, 90 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index f25dd4be81c7b..f0eccec43502c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -749,7 +749,7 @@ class ForStatementExec( defaultExpression, replace = true ) - new SingleStatementExec(declareVariable, Origin(), false) + new SingleStatementExec(declareVariable, Origin(), isInternal = true) } private def createSetVarExec(variable: Expression): SingleStatementExec = { @@ -759,7 +759,7 @@ class ForStatementExec( ) val setIdentifierToCurrentRow = SetVariable(Seq(UnresolvedAttribute(identifier.get)), projectNamedStruct) - new SingleStatementExec(setIdentifierToCurrentRow, Origin(), false) + new SingleStatementExec(setIdentifierToCurrentRow, Origin(), isInternal = true) } override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index b97dfb6741519..81dad9dbe10f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -98,6 +98,94 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for test iterate") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'); + | INSERT INTO t VALUES (2, 'second'); + | INSERT INTO t VALUES (3, 'third'); + | INSERT INTO t VALUES (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 2 THEN + | ITERATE lbl; + | END IF; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row("first")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 +// Seq.empty[Row], // drop x - TODO: uncomment when iterate can handle dropping vars + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 2 + Seq(Row("third")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 3 + Seq(Row("fourth")), // select stringCol + Seq.empty[Row], // drop x + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for test leave") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'); + | INSERT INTO t VALUES (2, 'second'); + | INSERT INTO t 
VALUES (3, 'third'); + | INSERT INTO t VALUES (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row("first")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + Seq(Row("second")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 2 +// Seq.empty[Row], // drop x + ) + verifySqlScriptResult(sqlScript, expected) + } + } + test("for test no variable") { withTable("t") { val sqlScript = From 130d0d17fccc35f02dd1ef15bfab7436b4b804a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 6 Nov 2024 14:14:18 +0100 Subject: [PATCH 08/39] add more tests --- .../scripting/SqlScriptingExecutionNode.scala | 7 +- .../SqlScriptingInterpreterSuite.scala | 64 +++++++++++++++++++ 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index f0eccec43502c..5c0e3393f925a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -652,7 +652,7 @@ class ForStatementExec( body: CompoundBodyExec, label: Option[String], session: SparkSession) extends NonLeafStatementExec { - // fali reset, case when identifier is None + private object ForState extends Enumeration { val VariableDeclaration, VariableAssignment, Body = Value } @@ -693,12 +693,14 @@ class ForStatementExec( override def next(): CompoundStatementExec = state match { case ForState.VariableDeclaration => + // when there is no for variable, skip var declaration and iterate only the body if (identifier.isEmpty) { state = ForState.Body body.reset() return next() } + // arguments of CreateNamedStruct must be formatted like (name1, val1, name2, val2, ...) 
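+      // e.g. a row (intCol = 1, stringCol = 'first') becomes the equivalent of
+      //   named_struct('intCol', 1, 'stringCol', 'first')
+      // (column names and values are illustrative, matching the interpreter suite tests)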
val namedStructArgs: Seq[Expression] = cachedQueryDataframe().schema.names.toSeq.flatMap { colName => Seq(Literal(colName), Literal(cachedQueryResult()(currRow).getAs(colName))) @@ -773,6 +775,3 @@ class ForStatementExec( body.reset() } } - -// val attributes = DataTypeUtils.toAttributes(queryResult.head.schema) -// LocalRelation.fromExternalRows(attributes, Seq(queryResult(0))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 81dad9dbe10f6..31f35a7e419b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -61,6 +61,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } + // todo: complex types in for test("for test") { withTable("t") { val sqlScript = @@ -98,6 +99,25 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for test empty result") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + ) + verifySqlScriptResult(sqlScript, expected) + } + } + test("for test iterate") { withTable("t") { val sqlScript = @@ -186,6 +206,50 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for test nested in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE i = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE i < 2 DO + | SET i = i + 1; + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | INSERT INTO t VALUES (i); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare i + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set i + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(0)), // select intCol + Seq.empty[Row], // drop x + Seq.empty[Row], // insert + Seq.empty[Row], // set i + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(0)), // select intCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + Seq(Row(1)), // select intCol + Seq.empty[Row], // drop x + Seq.empty[Row], // insert + Seq.empty[Row], // drop i + ) + verifySqlScriptResult(sqlScript, expected) + } + } + test("for test no variable") { withTable("t") { val sqlScript = From 3335e227a9880d515d111d26454955d05fac7d97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 7 Nov 2024 16:01:47 +0100 Subject: [PATCH 09/39] add support for map --- .../parser/SqlScriptingParserSuite.scala | 4 +- .../scripting/SqlScriptingExecutionNode.scala | 17 +++- .../SqlScriptingInterpreterSuite.scala | 84 ++++++++++++++++++- 3 files changed, 100 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index e0bd7023fd112..c8cce87f35c1f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -41,8 +41,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { // Tests test("testtest") { - val sqlScriptText = - """SELECT named_struct('a', 1, 'b', 2, 'c', 3); """.stripMargin + val sqlScriptText = "DECLARE my_map DEFAULT MAP('x', 0, 'y', 0);" + val tree = parseScript(sqlScriptText) assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 5c0e3393f925a..24f25e48aba15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.scripting +import scala.collection.immutable.Map import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType + /** * Trait for all SQL scripting execution nodes used during interpretation phase. */ @@ -703,7 +705,9 @@ class ForStatementExec( // arguments of CreateNamedStruct must be formatted like (name1, val1, name2, val2, ...) 
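+            // A plain Literal(...) can only encode atomic values; the change below routes
+            // each column value through createExpressionFromValue, which recursively
+            // rebuilds complex values (e.g. a Scala Map becomes a CreateMap expression).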
val namedStructArgs: Seq[Expression] = cachedQueryDataframe().schema.names.toSeq.flatMap { colName => - Seq(Literal(colName), Literal(cachedQueryResult()(currRow).getAs(colName))) + val valueExpression = + createExpressionFromValue(cachedQueryResult()(currRow).getAs(colName)) + Seq(Literal(colName), valueExpression) } currVariable = CreateNamedStruct(namedStructArgs) @@ -744,6 +748,15 @@ class ForStatementExec( } } + private def createExpressionFromValue(value: Any): Expression = value match { + case m: Map[_, _] => + val mapArgs = m.keys.zip(m.values).flatMap {case (k, v) => + Seq(createExpressionFromValue(k), createExpressionFromValue(v)) + }.toSeq + CreateMap(mapArgs, false) + case _ => Literal(value) + } + private def createDeclareVarExec(variable: Expression): SingleStatementExec = { val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") val declareVariable = CreateVariable( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 31f35a7e419b9..d4e95302246a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -60,8 +60,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } } + test("testetst") { + val sqlScript = "DECLARE my_map DEFAULT MAP('x', 0, 'y', 0);" + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + - // todo: complex types in for + // todo: complex types in for, better var names in tests test("for test") { withTable("t") { val sqlScript = @@ -99,6 +104,83 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for test complex types") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP>, struct_column STRUCT, array_column ARRAY); + | INSERT INTO t VALUES + | (1, MAP('a', MAP('1', 1)), STRUCT('John', 25), ARRAY('apple', 'banana')), + | (1, MAP('b', MAP('2', 2)), STRUCT('Jane', 30), ARRAY('apple', 'banana')); + | FOR row AS SELECT * FROM t ORDER BY int_column DO + | SELECT row.map_column; + | SELECT row.struct_column; + | SELECT row.array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + Seq(Row(2)), // select intCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row] // drop x + ) + verifySqlScriptResult(sqlScript, expected) + } + } + +// test("for test complex types") { +// withTable("t") { +// val sqlScript = +// """ +// |BEGIN +// | CREATE TABLE t (int_column INT, map_column MAP, struct_column STRUCT, array_column ARRAY); +// | INSERT INTO t VALUES +// | (1, MAP('a', 1, 'b', 2), STRUCT('John', 25), ARRAY('apple', 'banana')), +// | (2, MAP('c', 3, 'd', 4), STRUCT('Jane', 30), ARRAY('orange', 'grape')), +// | (3, MAP('e', 5, 'f', 6), STRUCT('Bob', 35), ARRAY('pear', 'peach')); +// | FOR row AS SELECT * FROM t 
ORDER BY int_column DO +// | SELECT row.map_column; +// | SELECT row.struct_column; +// | SELECT row.array_column; +// | END FOR; +// |END +// |""".stripMargin +// +// val expected = Seq( +// Seq.empty[Row], // create table +// Seq.empty[Row], // insert +// Seq.empty[Row], // insert +// Seq.empty[Row], // declare x +// Seq.empty[Row], // set x to row 0 +// Seq(Row(1)), // select intCol +// Seq(Row("first")), // select stringCol +// Seq(Row(1.0)), // select doubleCol +// Seq.empty[Row], // drop x +// Seq.empty[Row], // declare x +// Seq.empty[Row], // set x to row 1 +// Seq(Row(2)), // select intCol +// Seq(Row("second")), // select stringCol +// Seq(Row(2.0)), // select doubleCol +// Seq.empty[Row] // drop x +// ) +// verifySqlScriptResult(sqlScript, expected) +// } +// } + test("for test empty result") { withTable("t") { val sqlScript = From 85f99e92f6229a3864122dcc18dcf70c3631767e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 7 Nov 2024 16:21:48 +0100 Subject: [PATCH 10/39] add support for struct, seems to work --- .../parser/SqlScriptingParserSuite.scala | 8 ++++++++ .../scripting/SqlScriptingExecutionNode.scala | 20 ++++++++++--------- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index c8cce87f35c1f..de4517120ce84 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -48,6 +48,14 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection.head.isInstanceOf[ForStatement]) } + test("testtesttest") { + val sqlScriptText = "DECLARE my_struct DEFAULT STRUCT<'x', 0, 1.2>;" + + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + } + test("initial for") { val sqlScriptText = """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 24f25e48aba15..63337fa07405c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, GenericRowWithSchema, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors @@ -660,10 +660,11 @@ class ForStatementExec( } private var state = ForState.VariableDeclaration private var currRow = 0 - private var currVariable: CreateNamedStruct = null + private var currVariable: Expression = null private var queryDataframe: DataFrame = 
null private var isDataframeCacheValid = false + // todo: see if you need this private def cachedQueryDataframe(): DataFrame = { if (!isDataframeCacheValid) { query.isExecuted = true @@ -703,13 +704,7 @@ class ForStatementExec( } // arguments of CreateNamedStruct must be formatted like (name1, val1, name2, val2, ...) - val namedStructArgs: Seq[Expression] = - cachedQueryDataframe().schema.names.toSeq.flatMap { colName => - val valueExpression = - createExpressionFromValue(cachedQueryResult()(currRow).getAs(colName)) - Seq(Literal(colName), valueExpression) - } - currVariable = CreateNamedStruct(namedStructArgs) + currVariable = createExpressionFromValue(cachedQueryResult()(currRow)) state = ForState.VariableAssignment createDeclareVarExec(currVariable) @@ -754,6 +749,13 @@ class ForStatementExec( Seq(createExpressionFromValue(k), createExpressionFromValue(v)) }.toSeq CreateMap(mapArgs, false) + case s: GenericRowWithSchema => + val namedStructArgs = s.schema.names.toSeq.flatMap { colName => + val valueExpression = + createExpressionFromValue(s.getAs(colName)) + Seq(Literal(colName), valueExpression) + } + CreateNamedStruct(namedStructArgs) case _ => Literal(value) } From f1e126892e619bdb601a437c6311888b79052318 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 7 Nov 2024 16:30:51 +0100 Subject: [PATCH 11/39] clean up --- .../scripting/SqlScriptingExecutionNode.scala | 26 +++++-------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 63337fa07405c..1c08f000c2b55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.scripting -import scala.collection.immutable.Map import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, GenericRowWithSchema, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} @@ -662,23 +661,12 @@ class ForStatementExec( private var currRow = 0 private var currVariable: Expression = null - private var queryDataframe: DataFrame = null - private var isDataframeCacheValid = false - // todo: see if you need this - private def cachedQueryDataframe(): DataFrame = { - if (!isDataframeCacheValid) { - query.isExecuted = true - queryDataframe = Dataset.ofRows(session, query.parsedPlan) - isDataframeCacheValid = true - } - queryDataframe - } - private var queryResult: Array[Row] = null private var isResultCacheValid = false private def cachedQueryResult(): Array[Row] = { if (!isResultCacheValid) { - queryResult = cachedQueryDataframe().collect() + query.isExecuted = true + queryResult = Dataset.ofRows(session, query.parsedPlan).collect() isResultCacheValid = true } queryResult @@ -702,8 +690,6 @@ class ForStatementExec( body.reset() return next() } - - // arguments of CreateNamedStruct must be formatted like (name1, val1, name2, val2, ...) 
currVariable = createExpressionFromValue(cachedQueryResult()(currRow)) state = ForState.VariableAssignment @@ -745,12 +731,15 @@ class ForStatementExec( private def createExpressionFromValue(value: Any): Expression = value match { case m: Map[_, _] => + // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...) val mapArgs = m.keys.zip(m.values).flatMap {case (k, v) => Seq(createExpressionFromValue(k), createExpressionFromValue(v)) }.toSeq CreateMap(mapArgs, false) + // structs enter this case case s: GenericRowWithSchema => - val namedStructArgs = s.schema.names.toSeq.flatMap { colName => + // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...) + val namedStructArgs = s.schema.names.toSeq.flatMap { colName => val valueExpression = createExpressionFromValue(s.getAs(colName)) Seq(Literal(colName), valueExpression) @@ -783,7 +772,6 @@ class ForStatementExec( override def reset(): Unit = { state = ForState.VariableDeclaration - isDataframeCacheValid = false isResultCacheValid = false currRow = 0 currVariable = null From e41fa94fec93fb9bbc7c1ffbed83b68b17e50b0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 7 Nov 2024 16:48:01 +0100 Subject: [PATCH 12/39] fix comments and clean up code --- .../sql/catalyst/parser/AstBuilder.scala | 4 +-- .../logical/SqlScriptingLogicalPlans.scala | 14 ++++---- .../scripting/SqlScriptingExecutionNode.scala | 33 ++++++++++--------- 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 528fdb12a2628..a721aaa30b36e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -351,10 +351,10 @@ class AstBuilder extends DataTypeAstBuilder val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel())) val query = SingleStatement(visitQuery(ctx.query())) - val identifier = Option(ctx.multipartIdentifier()).map(_.getText) + val varName = Option(ctx.multipartIdentifier()).map(_.getText) val body = visitCompoundBody(ctx.compoundBody()) - ForStatement(query, identifier, body, Some(labelText)) + ForStatement(query, varName, body, Some(labelText)) } private def leaveOrIterateContextHasLabel( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index fd511a119b078..ce8d721b23eeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -269,16 +269,18 @@ case class LoopStatement( } /** - * Logical operator for REPEAT statement. - * @param body Compound body is a collection of statements that are executed once no matter what, - * and then as long as condition is false. + * Logical operator for FOR statement. 
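+ * For example, a FOR statement iterating a query result row by row might look like:
+ * {{{
+ *   FOR x AS SELECT * FROM t DO
+ *     SELECT x.intCol;
+ *   END FOR;
+ * }}}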
+ * @param query Query which is executed once, then its result is iterated on, row by row.
+ * @param variableName Name of the variable used to access the current row during iteration.
+ * @param body Compound body is a collection of statements that are executed once for each row in
+ *             the result set of the query.
  * @param label An optional label for the loop which is unique amongst all labels for statements
- *              within which the REPEAT statement is contained.
+ *              within which the FOR statement is contained.
  *              If an end label is specified it must match the beginning label.
  *              The label can be used to LEAVE or ITERATE the loop.
  */
 case class ForStatement(
   query: SingleStatement,
-  identifier: Option[String],
+  variableName: Option[String],
   body: CompoundBody,
-  label: Option[String]) extends CompoundPlanStatement
\ No newline at end of file
+  label: Option[String]) extends CompoundPlanStatement
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 1c08f000c2b55..0e2d495a5ff89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -641,15 +641,17 @@ class LoopStatementExec(
 }
 
 /**
- * Executable node for WhileStatement.
- * @param condition Executable node for the condition.
- * @param body Executable node for the body.
- * @param label Label set to WhileStatement by user or None otherwise.
+ * Executable node for ForStatement.
+ * @param query Executable node for the query.
+ * @param variableName Name of the variable used for accessing the current row during iteration.
+ * @param body Executable node for the body. If variableName is not None, it will have
+ *             DropVariable as the last statement.
+ * @param label Label set to ForStatement by user or None otherwise.
  * @param session Spark session that SQL script is executed within.
  */
 class ForStatementExec(
   query: SingleStatementExec,
-  identifier: Option[String],
+  variableName: Option[String],
   body: CompoundBodyExec,
   label: Option[String],
   session: SparkSession) extends NonLeafStatementExec {
@@ -685,7 +687,7 @@ class ForStatementExec(
   override def next(): CompoundStatementExec = state match {
     case ForState.VariableDeclaration =>
       // when there is no for variable, skip var declaration and iterate only the body
-      if (identifier.isEmpty) {
+      if (variableName.isEmpty) {
         state = ForState.Body
         body.reset()
         return next()
       }
       currVariable = createExpressionFromValue(cachedQueryResult()(currRow))
 
       state = ForState.VariableAssignment
-      createDeclareVarExec(currVariable)
+      createDeclareVarExec(variableName.get, currVariable)
 
     case ForState.VariableAssignment =>
       state = ForState.Body
       body.reset()
-      createSetVarExec(currVariable)
+      createSetVarExec(variableName.get, currVariable)
 
     case ForState.Body =>
       val retStmt = body.getTreeIterator.next()
@@ -736,35 +738,34 @@ class ForStatementExec(
       Seq(createExpressionFromValue(k), createExpressionFromValue(v))
     }.toSeq
     CreateMap(mapArgs, false)
-  // structs enter this case
   case s: GenericRowWithSchema =>
+    // struct types match this case
     // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...)
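+    // e.g. a row Row("John", 25) with schema fields (name, age) would become
+    // CreateNamedStruct(Seq(Literal("name"), Literal("John"), Literal("age"), Literal(25)))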
val namedStructArgs = s.schema.names.toSeq.flatMap { colName => - val valueExpression = - createExpressionFromValue(s.getAs(colName)) + val valueExpression = createExpressionFromValue(s.getAs(colName)) Seq(Literal(colName), valueExpression) } CreateNamedStruct(namedStructArgs) case _ => Literal(value) } - private def createDeclareVarExec(variable: Expression): SingleStatementExec = { + private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") val declareVariable = CreateVariable( - UnresolvedIdentifier(Seq(identifier.get)), + UnresolvedIdentifier(Seq(varName)), defaultExpression, replace = true ) new SingleStatementExec(declareVariable, Origin(), isInternal = true) } - private def createSetVarExec(variable: Expression): SingleStatementExec = { + private def createSetVarExec(varName: String, variable: Expression): SingleStatementExec = { val projectNamedStruct = Project( - Seq(Alias(variable, identifier.get)()), + Seq(Alias(variable, varName)()), OneRowRelation() ) val setIdentifierToCurrentRow = - SetVariable(Seq(UnresolvedAttribute(identifier.get)), projectNamedStruct) + SetVariable(Seq(UnresolvedAttribute(variableName.get)), projectNamedStruct) new SingleStatementExec(setIdentifierToCurrentRow, Origin(), isInternal = true) } From 57ec14b412553c21e898006d0f054251b12f1abb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 7 Nov 2024 16:52:45 +0100 Subject: [PATCH 13/39] formatting --- .../sql/scripting/SqlScriptingInterpreter.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index f67255ee210bc..1a182a152b787 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -123,21 +123,21 @@ case class SqlScriptingInterpreter() { val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] new LoopStatementExec(bodyExec, label) - case ForStatement(query, identifierOpt, body, label) => + case ForStatement(query, variableNameOpt, body, label) => val queryExec = new SingleStatementExec(query.parsedPlan, query.origin, isInternal = false) val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] - val finalExec = identifierOpt match { + val finalExec = variableNameOpt match { case None => bodyExec - case Some(identifier) => + case Some(variableName) => val dropVariableExec = new SingleStatementExec( - DropVariable(UnresolvedIdentifier(Seq(identifier)), ifExists = true), + DropVariable(UnresolvedIdentifier(Seq(variableName)), ifExists = true), Origin(), - isInternal = true) + isInternal = true + ) new CompoundBodyExec(Seq(bodyExec, dropVariableExec)) } - - new ForStatementExec(queryExec, identifierOpt, finalExec, label, session) + new ForStatementExec(queryExec, variableNameOpt, finalExec, label, session) case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) From f3e8a9c71d4ef9b5e6e6bcfd7b1ea090ed48a7d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 7 Nov 2024 19:46:18 +0100 Subject: [PATCH 14/39] change iterator to seq --- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 4 ++-- 
 .../sql/scripting/SqlScriptingInterpreterSuite.scala    | 8 ++++----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 0e2d495a5ff89..e845b18233f41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -734,9 +734,9 @@ class ForStatementExec(
   private def createExpressionFromValue(value: Any): Expression = value match {
     case m: Map[_, _] =>
       // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...)
-      val mapArgs = m.keys.zip(m.values).flatMap {case (k, v) =>
+      val mapArgs = m.keys.toSeq.zip(m.values.toSeq).flatMap { case (k, v) =>
         Seq(createExpressionFromValue(k), createExpressionFromValue(v))
-      }.toSeq
+      }
       CreateMap(mapArgs, false)
     case s: GenericRowWithSchema =>
       // struct types match this case
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index d4e95302246a5..52be809e18039 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -61,7 +61,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
   }
 
   test("testetst") {
-    val sqlScript = "DECLARE my_map DEFAULT MAP('x', 0, 'y', 0);"
+    val sqlScript = "DECLARE my_map DEFAULT MAP(1,1);"
     verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]])
   }
 
@@ -109,10 +109,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
     val sqlScript =
       """
         |BEGIN
-        | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<STRING, INT>>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
+        | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, INT>>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
         | INSERT INTO t VALUES
-        | (1, MAP('a', MAP('1', 1)), STRUCT('John', 25), ARRAY('apple', 'banana')),
-        | (1, MAP('b', MAP('2', 2)), STRUCT('Jane', 30), ARRAY('apple', 'banana'));
+        | (1, MAP('a', MAP(1, 1)), STRUCT('John', 25), ARRAY('apple', 'banana')),
+        | (1, MAP('b', MAP(2, 2)), STRUCT('Jane', 30), ARRAY('apple', 'banana'));
         | FOR row AS SELECT * FROM t ORDER BY int_column DO
From 18788c9e05005fcbc3bb522ed5667326533c597b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?=
Date: Thu, 7 Nov 2024 20:09:57 +0100
Subject: [PATCH 15/39] improve iteration logic for map args

---
 .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index e845b18233f41..73e7af525788b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -734,8 +734,8 @@ class ForStatementExec(
   private def createExpressionFromValue(value: Any): Expression = value match {
     case m: Map[_, _] =>
       // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...)
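      // e.g. Map(1 -> 2) becomes CreateMap(Seq(Literal(1), Literal(2)))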
- val mapArgs = m.keys.toSeq.zip(m.values.toSeq).flatMap { case (k, v) => - Seq(createExpressionFromValue(k), createExpressionFromValue(v)) + val mapArgs = m.keys.toSeq.flatMap { key => + Seq(createExpressionFromValue(key), createExpressionFromValue(m(key))) } CreateMap(mapArgs, false) case s: GenericRowWithSchema => From b3c7145770e6aa93306a2883aa224fde30e25576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 8 Nov 2024 11:38:07 +0100 Subject: [PATCH 16/39] add parser test --- .../sql/catalyst/parser/AstBuilder.scala | 5 +- .../parser/SqlScriptingParserSuite.scala | 233 +++++++++++++++++- .../scripting/SqlScriptingExecutionNode.scala | 4 +- 3 files changed, 239 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a721aaa30b36e..e8a6c4581d595 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -350,7 +350,10 @@ class AstBuilder extends DataTypeAstBuilder override def visitForStatement(ctx: ForStatementContext): ForStatement = { val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel())) - val query = SingleStatement(visitQuery(ctx.query())) + val queryCtx = ctx.query() + val query = withOrigin(queryCtx) { + SingleStatement(visitQuery(queryCtx)) + } val varName = Option(ctx.multipartIdentifier()).map(_.getText) val body = visitCompoundBody(ctx.compoundBody()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index de4517120ce84..cd7df6255ffe0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -1206,7 +1206,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT 42") assert(whileStmt.label.contains("lbl")) - } test("searched case statement") { @@ -1937,6 +1936,238 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(repeatStatement.label.get == "lbl_3") } + test("for statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement no label") { + val sqlScriptText = + """ + |BEGIN + | FOR x AS SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = 
tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + // when not explicitly set, label is random UUID + assert(forStmt.label.isDefined) + } + + test("for statement with complex subquery") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR x AS SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO + | SELECT x.c1; + | SELECT x.c2; + | END FOR; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1") + assert(forStmt.variableName.contains("x")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 2) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT x.c1") + assert(forStmt.body.collection(1).isInstanceOf[SingleStatement]) + assert(forStmt.body.collection(1).asInstanceOf[SingleStatement].getText == "SELECT x.c2") + + assert(forStmt.label.contains("lbl")) + } + + test("nested for statement") { + val sqlScriptText = + """ + |BEGIN + | lbl1: FOR i AS SELECT 1 DO + | lbl2: FOR j AS SELECT 2 DO + | SELECT i + j; + | END FOR lbl2; + | END FOR lbl1; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 1") + assert(forStmt.variableName.contains("i")) + assert(forStmt.label.contains("lbl1")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[ForStatement]) + val nestedForStmt = forStmt.body.collection.head.asInstanceOf[ForStatement] + + assert(nestedForStmt.query.isInstanceOf[SingleStatement]) + assert(nestedForStmt.query.getText == "SELECT 2") + assert(nestedForStmt.variableName.contains("j")) + assert(nestedForStmt.label.contains("lbl2")) + + assert(nestedForStmt.body.isInstanceOf[CompoundBody]) + assert(nestedForStmt.body.collection.length == 1) + assert(nestedForStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedForStmt.body.collection. 
+ head.asInstanceOf[SingleStatement].getText == "SELECT i + j") + } + + test("for statement no variable") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(forStmt.label.contains("lbl")) + } + + test("for statement no label no variable") { + val sqlScriptText = + """ + |BEGIN + | FOR SELECT 5 DO + | SELECT 1; + | END FOR; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 5") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + // when not explicitly set, label is random UUID + assert(forStmt.label.isDefined) + } + + test("for statement with complex subquery no variable") { + val sqlScriptText = + """ + |BEGIN + | lbl: FOR SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1 DO + | SELECT 1; + | SELECT 2; + | END FOR; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT c1, c2 FROM t WHERE c2 = 5 GROUP BY c1 ORDER BY c1") + assert(forStmt.variableName.isEmpty) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 2) + assert(forStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(forStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + assert(forStmt.body.collection(1).isInstanceOf[SingleStatement]) + assert(forStmt.body.collection(1).asInstanceOf[SingleStatement].getText == "SELECT 2") + + assert(forStmt.label.contains("lbl")) + } + + test("nested for statement no variable") { + val sqlScriptText = + """ + |BEGIN + | lbl1: FOR SELECT 1 DO + | lbl2: FOR SELECT 2 DO + | SELECT 3; + | END FOR lbl2; + | END FOR lbl1; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[ForStatement]) + + val forStmt = tree.collection.head.asInstanceOf[ForStatement] + assert(forStmt.query.isInstanceOf[SingleStatement]) + assert(forStmt.query.getText == "SELECT 1") + assert(forStmt.variableName.isEmpty) + assert(forStmt.label.contains("lbl1")) + + assert(forStmt.body.isInstanceOf[CompoundBody]) + assert(forStmt.body.collection.length == 1) + assert(forStmt.body.collection.head.isInstanceOf[ForStatement]) + val 
nestedForStmt = forStmt.body.collection.head.asInstanceOf[ForStatement]
+
+    assert(nestedForStmt.query.isInstanceOf[SingleStatement])
+    assert(nestedForStmt.query.getText == "SELECT 2")
+    assert(nestedForStmt.variableName.isEmpty)
+    assert(nestedForStmt.label.contains("lbl2"))
+
+    assert(nestedForStmt.body.isInstanceOf[CompoundBody])
+    assert(nestedForStmt.body.collection.length == 1)
+    assert(nestedForStmt.body.collection.head.isInstanceOf[SingleStatement])
+    assert(nestedForStmt.body.collection.
+      head.asInstanceOf[SingleStatement].getText == "SELECT 3")
+  }
+
   // Helper methods
   def cleanupStatementString(statementStr: String): String = {
     statementStr
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 73e7af525788b..8c641b4170986 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -693,7 +693,6 @@ class ForStatementExec(
         return next()
       }
       currVariable = createExpressionFromValue(cachedQueryResult()(currRow))
-
       state = ForState.VariableAssignment
       createDeclareVarExec(variableName.get, currVariable)
 
@@ -731,6 +730,9 @@ class ForStatementExec(
     }
   }
 
+  /**
+   * Creates a Catalyst expression from a Scala value.
+   */
   private def createExpressionFromValue(value: Any): Expression = value match {
     case m: Map[_, _] =>
       // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...)
From 0eb6184657de22c5fee46432edd0bc463019eed0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?=
Date: Fri, 8 Nov 2024 12:21:33 +0100
Subject: [PATCH 17/39] indentation

---
 .../plans/logical/SqlScriptingLogicalPlans.scala   |  8 ++++----
 .../sql/scripting/SqlScriptingExecutionNode.scala  | 10 +++++-----
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index ce8d721b23eeb..63e919088ece4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -280,7 +280,7 @@ case class LoopStatement(
  *              The label can be used to LEAVE or ITERATE the loop.
  */
 case class ForStatement(
-  query: SingleStatement,
-  variableName: Option[String],
-  body: CompoundBody,
-  label: Option[String]) extends CompoundPlanStatement
+    query: SingleStatement,
+    variableName: Option[String],
+    body: CompoundBody,
+    label: Option[String]) extends CompoundPlanStatement
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index 8c641b4170986..6d7f9de28b873 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -650,11 +650,11 @@ class LoopStatementExec(
  * @param session Spark session that SQL script is executed within.
*/ class ForStatementExec( - query: SingleStatementExec, - variableName: Option[String], - body: CompoundBodyExec, - label: Option[String], - session: SparkSession) extends NonLeafStatementExec { + query: SingleStatementExec, + variableName: Option[String], + body: CompoundBodyExec, + label: Option[String], + session: SparkSession) extends NonLeafStatementExec { private object ForState extends Enumeration { val VariableDeclaration, VariableAssignment, Body = Value From 338bb813c0a3f481553a86e1efb6a5329144e48a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 8 Nov 2024 15:41:39 +0100 Subject: [PATCH 18/39] execution node tests --- .../scripting/SqlScriptingExecutionNode.scala | 18 +- .../SqlScriptingExecutionNodeSuite.scala | 189 +++++++++++++++++- 2 files changed, 199 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 6d7f9de28b873..4c2e8f73789c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, GenericRowWithSchema, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} @@ -125,6 +125,17 @@ class SingleStatementExec( */ var isExecuted = false + /** + * Builds a DataFrame from the parsedPlan of this SingleStatementExec + * @param session The SparkSession on which the parsedPlan is built + * @return + * The DataFrame. + */ + def buildDataFrame(session: SparkSession): DataFrame = { + isExecuted = true + Dataset.ofRows(session, parsedPlan) + } + /** * Get the SQL query text corresponding to this statement. 
* @return @@ -653,7 +664,7 @@ class ForStatementExec( query: SingleStatementExec, variableName: Option[String], body: CompoundBodyExec, - label: Option[String], + val label: Option[String], session: SparkSession) extends NonLeafStatementExec { private object ForState extends Enumeration { @@ -667,8 +678,7 @@ class ForStatementExec( private var isResultCacheValid = false private def cachedQueryResult(): Array[Row] = { if (!isResultCacheValid) { - query.isExecuted = true - queryResult = Dataset.ofRows(session, query.parsedPlan).collect() + queryResult = query.buildDataFrame(session).collect() isResultCacheValid = true } queryResult diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index baad5702f4f22..098d9ddadcd81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} /** * Unit tests for execution nodes from SqlScriptingExecutionNode.scala. @@ -80,9 +81,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi } case class TestRepeat( - condition: TestLoopCondition, - body: CompoundBodyExec, - label: Option[String] = None) + condition: TestLoopCondition, + body: CompoundBodyExec, + label: Option[String] = None) extends RepeatStatementExec(condition, body, label, spark) { private val evaluator = new LoopBooleanConditionEvaluator(condition) @@ -92,6 +93,22 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() } + case class TestForStatementQuery(numberOfRows: Int, description: String) + extends SingleStatementExec( + DummyLogicalPlan(), + Origin(startIndex = Some(0), stopIndex = Some(description.length)), + isInternal = false) { + override def buildDataFrame(session: SparkSession): DataFrame = { + val data = Seq.range(0, numberOfRows).map(Row(_)) + val schema = List(StructField("intCol", IntegerType)) + + spark.createDataFrame( + spark.sparkContext.parallelize(data), + StructType(schema) + ) + } + } + private def extractStatementValue(statement: CompoundStatementExec): String = statement match { case TestLeafStatement(testVal) => testVal @@ -100,6 +117,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case loopStmt: LoopStatementExec => loopStmt.label.get case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label + case forStmt: ForStatementExec => forStmt.label.get + case _: SingleStatementExec => "SingleStatementExec" case _ => fail("Unexpected statement type") } @@ -686,4 +705,166 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === 
Seq("body1", "lbl")) } + + test("for statement enters body once") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(1, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), + label = Some("for1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare var + "SingleStatementExec", // set var + "body" + )) + } + + test("for statement enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + TestLeafStatement("statement2"))), + label = Some("for1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare var + "SingleStatementExec", // set var + "statement1", + "statement2", + "SingleStatementExec", // declare var + "SingleStatementExec", // set var + "statement1", + "statement2" + )) + } + + test("for statement empty result") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(0, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), + label = Some("for1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq.empty[String]) + } + + test("for statement nested") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query2"), + variableName = Some("y"), + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), + label = Some("for2"), + session = spark + ) + )), + label = Some("for1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare x + "SingleStatementExec", // set x + "SingleStatementExec", // declare y + "SingleStatementExec", // set y + "body", + "SingleStatementExec", // declare y + "SingleStatementExec", // set y + "body", + "SingleStatementExec", // declare x + "SingleStatementExec", // set x + "SingleStatementExec", // declare y + "SingleStatementExec", // set y + "body", + "SingleStatementExec", // declare y + "SingleStatementExec", // set y + "body" + )) + } + + test("for statement no variable enters body once") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(1, "query1"), + variableName = None, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), + label = Some("for1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body")) + } + + test("for statement no variable enters body with multiple statements multiple times") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = None, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + TestLeafStatement("statement2"))), + label = Some("for1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq 
+    assert(statements === Seq("statement1", "statement2", "statement1", "statement2"))
+  }
+
+  test("for statement no variable empty result") {
+    val iter = new CompoundBodyExec(Seq(
+      new ForStatementExec(
+        query = TestForStatementQuery(0, "query1"),
+        variableName = None,
+        body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))),
+        label = Some("for1"),
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq.empty[String])
+  }
+
+  test("for statement no variable nested") {
+    val iter = new CompoundBodyExec(Seq(
+      new ForStatementExec(
+        query = TestForStatementQuery(2, "query1"),
+        variableName = None,
+        body = new CompoundBodyExec(Seq(
+          new ForStatementExec(
+            query = TestForStatementQuery(2, "query2"),
+            variableName = None,
+            body = new CompoundBodyExec(Seq(TestLeafStatement("body"))),
+            label = Some("for2"),
+            session = spark
+          )
+        )),
+        label = Some("for1"),
+        session = spark
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("body", "body", "body", "body"))
+  }
 }
From 4910fc2ebbf5325346a0b2f51f52b2f6a6417c48 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?=
Date: Fri, 8 Nov 2024 16:52:42 +0100
Subject: [PATCH 19/39] execution node tests - iterate and leave

---
 .../SqlScriptingExecutionNodeSuite.scala      | 218 +++++++++++++++++-
 1 file changed, 210 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index 098d9ddadcd81..accab422ccf56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -706,7 +706,7 @@ class
SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi assert(statements === Seq("body")) } - test("for statement no variable enters body with multiple statements multiple times") { + test("for statement no variable - enters body with multiple statements multiple times") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( query = TestForStatementQuery(2, "query1"), @@ -832,7 +832,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi assert(statements === Seq("statement1", "statement2", "statement1", "statement2")) } - test("for statement no variable empty result") { + test("for statement no variable - empty result") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( query = TestForStatementQuery(0, "query1"), @@ -846,7 +846,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi assert(statements === Seq.empty[String]) } - test("for statement no variable nested") { + test("for statement no variable - nested") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( query = TestForStatementQuery(2, "query1"), @@ -867,4 +867,206 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq("body", "body", "body", "body")) } + + test("for statement - iterate") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("statement2"))), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare var + "SingleStatementExec", // set var + "statement1", + "lbl1", + "SingleStatementExec", // declare var + "SingleStatementExec", // set var + "statement1", + "lbl1" + )) + } + + test("for statement - leave") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("statement2"))), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare var + "SingleStatementExec", // set var + "statement1", + "lbl1" + )) + } + + test("for statement - nested - iterate outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query2"), + variableName = Some("y"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("body2"))), + label = Some("lbl2"), + session = spark + ) + )), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare x + "SingleStatementExec", // set x + "SingleStatementExec", // declare y + "SingleStatementExec", // set y + "body1", + "lbl1", + "SingleStatementExec", // declare x + "SingleStatementExec", // set x + "SingleStatementExec", // declare y + 
"SingleStatementExec", // set y + "body1", + "lbl1" + )) + } + + test("for statement - nested - leave outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = Some("x"), + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query2"), + variableName = Some("y"), + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("body2"))), + label = Some("lbl2"), + session = spark + ) + )), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq( + "SingleStatementExec", // declare x + "SingleStatementExec", // set x + "SingleStatementExec", // declare y + "SingleStatementExec", // set y + "body1", + "lbl1" + )) + } + + test("for statement no variable - iterate") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = None, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("statement2"))), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "lbl1", "statement1", "lbl1")) + } + + test("for statement no variable - leave") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = None, + body = new CompoundBodyExec(Seq( + TestLeafStatement("statement1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("statement2"))), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("statement1", "lbl1")) + } + + test("for statement no variable - nested - iterate outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = None, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query2"), + variableName = None, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new IterateStatementExec("lbl1"), + TestLeafStatement("body2"))), + label = Some("lbl2"), + session = spark + ) + )), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl1", "body1", "lbl1")) + } + + test("for statement no variable - nested - leave outer loop") { + val iter = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query1"), + variableName = None, + body = new CompoundBodyExec(Seq( + new ForStatementExec( + query = TestForStatementQuery(2, "query2"), + variableName = None, + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl1"), + TestLeafStatement("body2"))), + label = Some("lbl2"), + session = spark + ) + )), + label = Some("lbl1"), + session = spark + ) + )).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl1")) + } } From 9c01b57c62d90ae00d17c723dfcce2531047b5f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 8 Nov 2024 18:59:50 +0100 Subject: [PATCH 20/39] start interpreter tests --- 
.../scripting/SqlScriptingExecutionNode.scala | 1 - .../SqlScriptingInterpreterSuite.scala | 592 +++++++++--------- 2 files changed, 295 insertions(+), 298 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 4c2e8f73789c6..328f775919f5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType - /** * Trait for all SQL scripting execution nodes used during interpretation phase. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 52be809e18039..4965c2e7ff4cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -60,303 +60,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { result.zip(expected).foreach { case (df, expectedAnswer) => checkAnswer(df, expectedAnswer) } } - test("testetst") { - val sqlScript = "DECLARE my_map DEFAULT MAP(1,1);" - verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) - } - - - // todo: complex types in for, better var names in tests - test("for test") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; - | INSERT INTO t VALUES (1, 'first', 1.0); - | INSERT INTO t VALUES (2, 'second', 2.0); - | FOR x AS SELECT * FROM t ORDER BY intCol DO - | SELECT x.intCol; - | SELECT x.stringCol; - | SELECT x.doubleCol; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 - Seq(Row(1)), // select intCol - Seq(Row("first")), // select stringCol - Seq(Row(1.0)), // select doubleCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 - Seq(Row(2)), // select intCol - Seq(Row("second")), // select stringCol - Seq(Row(2.0)), // select doubleCol - Seq.empty[Row] // drop x - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("for test complex types") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (int_column INT, map_column MAP>, struct_column STRUCT, array_column ARRAY); - | INSERT INTO t VALUES - | (1, MAP('a', MAP(1, 1)), STRUCT('John', 25), ARRAY('apple', 'banana')), - | (1, MAP('b', MAP(2, 2)), STRUCT('Jane', 30), ARRAY('apple', 'banana')); - | FOR row AS SELECT * FROM t ORDER BY int_column DO - | SELECT row.map_column; - | SELECT row.struct_column; - | SELECT row.array_column; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 - Seq(Row(1)), // select intCol - Seq(Row("first")), // select stringCol - Seq(Row(1.0)), // select doubleCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - 
Seq.empty[Row], // set x to row 1 - Seq(Row(2)), // select intCol - Seq(Row("second")), // select stringCol - Seq(Row(2.0)), // select doubleCol - Seq.empty[Row] // drop x - ) - verifySqlScriptResult(sqlScript, expected) - } - } - -// test("for test complex types") { -// withTable("t") { -// val sqlScript = -// """ -// |BEGIN -// | CREATE TABLE t (int_column INT, map_column MAP, struct_column STRUCT, array_column ARRAY); -// | INSERT INTO t VALUES -// | (1, MAP('a', 1, 'b', 2), STRUCT('John', 25), ARRAY('apple', 'banana')), -// | (2, MAP('c', 3, 'd', 4), STRUCT('Jane', 30), ARRAY('orange', 'grape')), -// | (3, MAP('e', 5, 'f', 6), STRUCT('Bob', 35), ARRAY('pear', 'peach')); -// | FOR row AS SELECT * FROM t ORDER BY int_column DO -// | SELECT row.map_column; -// | SELECT row.struct_column; -// | SELECT row.array_column; -// | END FOR; -// |END -// |""".stripMargin -// -// val expected = Seq( -// Seq.empty[Row], // create table -// Seq.empty[Row], // insert -// Seq.empty[Row], // insert -// Seq.empty[Row], // declare x -// Seq.empty[Row], // set x to row 0 -// Seq(Row(1)), // select intCol -// Seq(Row("first")), // select stringCol -// Seq(Row(1.0)), // select doubleCol -// Seq.empty[Row], // drop x -// Seq.empty[Row], // declare x -// Seq.empty[Row], // set x to row 1 -// Seq(Row(2)), // select intCol -// Seq(Row("second")), // select stringCol -// Seq(Row(2.0)), // select doubleCol -// Seq.empty[Row] // drop x -// ) -// verifySqlScriptResult(sqlScript, expected) -// } -// } - - test("for test empty result") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (intCol INT) using parquet; - | FOR x AS SELECT * FROM t ORDER BY intCol DO - | SELECT x.intCol; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("for test iterate") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; - | INSERT INTO t VALUES (1, 'first'); - | INSERT INTO t VALUES (2, 'second'); - | INSERT INTO t VALUES (3, 'third'); - | INSERT INTO t VALUES (4, 'fourth'); - | - | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO - | IF x.intCol = 2 THEN - | ITERATE lbl; - | END IF; - | SELECT x.stringCol; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 - Seq(Row("first")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 -// Seq.empty[Row], // drop x - TODO: uncomment when iterate can handle dropping vars - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 2 - Seq(Row("third")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 3 - Seq(Row("fourth")), // select stringCol - Seq.empty[Row], // drop x - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("for test leave") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; - | INSERT INTO t VALUES (1, 'first'); - | INSERT INTO t VALUES (2, 'second'); - | INSERT INTO t VALUES (3, 'third'); - | INSERT INTO t VALUES (4, 'fourth'); - | - | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO - | IF x.intCol = 3 THEN - | LEAVE lbl; - | END IF; - | SELECT 
x.stringCol; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 - Seq(Row("first")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 - Seq(Row("second")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 2 -// Seq.empty[Row], // drop x - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("for test nested in while") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | DECLARE i = 0; - | CREATE TABLE t (intCol INT) using parquet; - | INSERT INTO t VALUES (0); - | WHILE i < 2 DO - | SET i = i + 1; - | FOR x AS SELECT * FROM t ORDER BY intCol DO - | SELECT x.intCol; - | END FOR; - | INSERT INTO t VALUES (i); - | END WHILE; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // declare i - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // set i - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 - Seq(Row(0)), // select intCol - Seq.empty[Row], // drop x - Seq.empty[Row], // insert - Seq.empty[Row], // set i - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 - Seq(Row(0)), // select intCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 - Seq(Row(1)), // select intCol - Seq.empty[Row], // drop x - Seq.empty[Row], // insert - Seq.empty[Row], // drop i - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("for test no variable") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; - | INSERT INTO t VALUES (1, 'first', 1.0); - | INSERT INTO t VALUES (2, 'second', 2.0); - | FOR SELECT * FROM t ORDER BY intCol DO - | SELECT 1; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq(Row(1)), // select 1 - Seq(Row(1)), // select 1 - ) - verifySqlScriptResult(sqlScript, expected) - } - } - // Tests test("multi statement - simple") { withTable("t") { @@ -1844,4 +1547,299 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, expected) } + + test("testetst") { + val sqlScript = "DECLARE my_map DEFAULT MAP(1,1);" + verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + } + + // todo: complex types in for, better var names in tests + test("for test") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | SELECT x.stringCol; + | SELECT x.doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + Seq(Row(2)), // select intCol + Seq(Row("second")), 
// select stringCol
+ Seq(Row(2.0)), // select doubleCol
+ Seq.empty[Row] // drop x
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for test complex types") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, INT>>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
+ | INSERT INTO t VALUES
+ | (1, MAP('a', MAP(1, 10)), STRUCT('John', 25), ARRAY('apple', 'banana')),
+ | (2, MAP('b', MAP(2, 20)), STRUCT('Jane', 30), ARRAY('apple', 'banana'));
+ | FOR row AS SELECT * FROM t ORDER BY int_column DO
+ | SELECT row.map_column;
+ | SELECT row.struct_column;
+ | SELECT row.array_column;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ Seq.empty[Row], // insert
+ Seq.empty[Row], // declare x
+ Seq.empty[Row], // set x to row 0
+ Seq(Row(Map("a" -> Map(1 -> 10)))), // select map_column
+ Seq(Row(Row("John", 25))), // select struct_column
+ Seq(Row(Array("apple", "banana"))), // select array_column
+ Seq.empty[Row], // drop x
+ Seq.empty[Row], // declare x
+ Seq.empty[Row], // set x to row 1
+ Seq(Row(Map("b" -> Map(2 -> 20)))), // select map_column
+ Seq(Row(Row("Jane", 30))), // select struct_column
+ Seq(Row(Array("apple", "banana"))), // select array_column
+ Seq.empty[Row] // drop x
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+// test("for test complex types") {
+// withTable("t") {
+// val sqlScript =
+// """
+// |BEGIN
+// | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
+// | INSERT INTO t VALUES
+// | (1, MAP('a', 1, 'b', 2), STRUCT('John', 25), ARRAY('apple', 'banana')),
+// | (2, MAP('c', 3, 'd', 4), STRUCT('Jane', 30), ARRAY('orange', 'grape')),
+// | (3, MAP('e', 5, 'f', 6), STRUCT('Bob', 35), ARRAY('pear', 'peach'));
+// | FOR row AS SELECT * FROM t ORDER BY int_column DO
+// | SELECT row.map_column;
+// | SELECT row.struct_column;
+// | SELECT row.array_column;
+// | END FOR;
+// |END
+// |""".stripMargin
+//
+// val expected = Seq(
+// Seq.empty[Row], // create table
+// Seq.empty[Row], // insert
+// Seq.empty[Row], // insert
+// Seq.empty[Row], // declare x
+// Seq.empty[Row], // set x to row 0
+// Seq(Row(1)), // select intCol
+// Seq(Row("first")), // select stringCol
+// Seq(Row(1.0)), // select doubleCol
+// Seq.empty[Row], // drop x
+// Seq.empty[Row], // declare x
+// Seq.empty[Row], // set x to row 1
+// Seq(Row(2)), // select intCol
+// Seq(Row("second")), // select stringCol
+// Seq(Row(2.0)), // select doubleCol
+// Seq.empty[Row] // drop x
+// )
+// verifySqlScriptResult(sqlScript, expected)
+// }
+// }
+
+ test("for test empty result") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT) using parquet;
+ | FOR x AS SELECT * FROM t ORDER BY intCol DO
+ | SELECT x.intCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // create table
+ )
+ verifySqlScriptResult(sqlScript, expected)
+ }
+ }
+
+ test("for test iterate") {
+ withTable("t") {
+ val sqlScript =
+ """
+ |BEGIN
+ | CREATE TABLE t (intCol INT, stringCol STRING) using parquet;
+ | INSERT INTO t VALUES (1, 'first');
+ | INSERT INTO t VALUES (2, 'second');
+ | INSERT INTO t VALUES (3, 'third');
+ | INSERT INTO t VALUES (4, 'fourth');
+ |
+ | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO
+ | IF x.intCol = 2 THEN
+ | ITERATE lbl;
+ | END IF;
+ | SELECT x.stringCol;
+ | END FOR;
+ |END
+ |""".stripMargin
+
+ val expected = Seq(
+ Seq.empty[Row], // 
create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row("first")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + // Seq.empty[Row], // drop x - TODO: uncomment when iterate can handle dropping vars + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 2 + Seq(Row("third")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 3 + Seq(Row("fourth")), // select stringCol + Seq.empty[Row], // drop x + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for test leave") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'); + | INSERT INTO t VALUES (2, 'second'); + | INSERT INTO t VALUES (3, 'third'); + | INSERT INTO t VALUES (4, 'fourth'); + | + | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO + | IF x.intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT x.stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row("first")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + Seq(Row("second")), // select stringCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 2 + // Seq.empty[Row], // drop x + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for test nested in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE i = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE i < 2 DO + | SET i = i + 1; + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | INSERT INTO t VALUES (i); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare i + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set i + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(0)), // select intCol + Seq.empty[Row], // drop x + Seq.empty[Row], // insert + Seq.empty[Row], // set i + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 0 + Seq(Row(0)), // select intCol + Seq.empty[Row], // drop x + Seq.empty[Row], // declare x + Seq.empty[Row], // set x to row 1 + Seq(Row(1)), // select intCol + Seq.empty[Row], // drop x + Seq.empty[Row], // insert + Seq.empty[Row], // drop i + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for test no variable") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT 1; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select 1 + Seq(Row(1)), // select 1 + ) + verifySqlScriptResult(sqlScript, expected) + } + } } From 616c94cb8ed2278d11fa4578a7efa45d8e7c4233 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 12 Nov 2024 18:18:27 +0100 Subject: [PATCH 21/39] refactor to support column access without qualifying --- .../scripting/SqlScriptingExecutionNode.scala | 64 +++++++++++------ .../scripting/SqlScriptingInterpreter.scala | 12 +--- .../SqlScriptingInterpreterSuite.scala | 70 ++++++------------- 3 files changed, 67 insertions(+), 79 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 328f775919f5a..aebfa29f5cdbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, GenericRowWithSchema, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors @@ -667,11 +667,15 @@ class ForStatementExec( session: SparkSession) extends NonLeafStatementExec { private object ForState extends Enumeration { - val VariableDeclaration, VariableAssignment, Body = Value + val VariableAssignment, Body = Value } - private var state = ForState.VariableDeclaration + private var state = ForState.VariableAssignment private var currRow = 0 - private var currVariable: Expression = null + private var areVariablesDeclared = false + + // map of all variables created internally by the for statement + // (variableName -> variableExpression) + private var variablesMap: Map[String, Expression] = Map() private var queryResult: Array[Row] = null private var isResultCacheValid = false @@ -694,21 +698,22 @@ class ForStatementExec( !interrupted && cachedQueryResult().length > 0 && currRow < cachedQueryResult().length override def next(): CompoundStatementExec = state match { - case ForState.VariableDeclaration => - // when there is no for variable, skip var declaration and iterate only the body - if (variableName.isEmpty) { - state = ForState.Body - body.reset() - return next() - } - currVariable = createExpressionFromValue(cachedQueryResult()(currRow)) - state = ForState.VariableAssignment - createDeclareVarExec(variableName.get, currVariable) case ForState.VariableAssignment => + variablesMap = createVariablesMapFromRow(currRow) + + if (!areVariablesDeclared) { + variablesMap.keys.toSeq + .map(colName => createDeclareVarExec(colName, variablesMap(colName))) + .foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect()) + areVariablesDeclared = true + } + variablesMap.keys.toSeq + .map(colName => createSetVarExec(colName, variablesMap(colName))) + .foreach(exec => exec.buildDataFrame(session).collect()) state = ForState.Body body.reset() - createSetVarExec(variableName.get, currVariable) + next() case ForState.Body => val retStmt = body.getTreeIterator.next() @@ -720,20 +725,21 @@ class ForStatementExec( 
leaveStatementExec.hasBeenMatched = true } interrupted = true + // drop vars return retStmt case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => if (label.contains(iterStatementExec.label)) { iterStatementExec.hasBeenMatched = true } currRow += 1 - state = ForState.VariableDeclaration + state = ForState.VariableAssignment return retStmt case _ => } if (!body.getTreeIterator.hasNext) { currRow += 1 - state = ForState.VariableDeclaration + state = ForState.VariableAssignment } retStmt } @@ -749,7 +755,7 @@ class ForStatementExec( Seq(createExpressionFromValue(key), createExpressionFromValue(m(key))) } CreateMap(mapArgs, false) - case s: GenericRowWithSchema => + case s: Row => // struct types match this case // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...) val namedStructArgs = s.schema.names.toSeq.flatMap { colName => @@ -760,6 +766,22 @@ class ForStatementExec( case _ => Literal(value) } + private def createVariablesMapFromRow(rowIndex: Int): Map[String, Expression] = { + val row = cachedQueryResult()(rowIndex) + var variablesMap = row.schema.names.toSeq.map { colName => + colName -> createExpressionFromValue(row.getAs(colName)) + }.toMap + + if (variableName.isDefined) { + val namedStructArgs = variablesMap.keys.toSeq.flatMap { colName => + Seq(Literal(colName), variablesMap(colName)) + } + val forVariable = CreateNamedStruct(namedStructArgs) + variablesMap = variablesMap + (variableName.get -> forVariable) + } + variablesMap + } + private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") val declareVariable = CreateVariable( @@ -776,17 +798,17 @@ class ForStatementExec( OneRowRelation() ) val setIdentifierToCurrentRow = - SetVariable(Seq(UnresolvedAttribute(variableName.get)), projectNamedStruct) + SetVariable(Seq(UnresolvedAttribute(varName)), projectNamedStruct) new SingleStatementExec(setIdentifierToCurrentRow, Origin(), isInternal = true) } override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator override def reset(): Unit = { - state = ForState.VariableDeclaration + state = ForState.VariableAssignment isResultCacheValid = false currRow = 0 - currVariable = null + variablesMap = Map() body.reset() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 1a182a152b787..3b6dcd63f0579 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -127,17 +127,7 @@ case class SqlScriptingInterpreter() { val queryExec = new SingleStatementExec(query.parsedPlan, query.origin, isInternal = false) val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] - val finalExec = variableNameOpt match { - case None => bodyExec - case Some(variableName) => - val dropVariableExec = new SingleStatementExec( - DropVariable(UnresolvedIdentifier(Seq(variableName)), ifExists = true), - Origin(), - isInternal = true - ) - new CompoundBodyExec(Seq(bodyExec, dropVariableExec)) - } - new ForStatementExec(queryExec, variableNameOpt, finalExec, label, session) + new ForStatementExec(queryExec, variableNameOpt, bodyExec, label, session) case leaveStatement: LeaveStatement => new 
LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 4965c2e7ff4cb..9f47027c57fa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1564,8 +1564,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | INSERT INTO t VALUES (2, 'second', 2.0); | FOR x AS SELECT * FROM t ORDER BY intCol DO | SELECT x.intCol; + | SELECT intCol; | SELECT x.stringCol; + | SELECT stringCol; | SELECT x.doubleCol; + | SELECT doubleCol; | END FOR; |END |""".stripMargin @@ -1574,18 +1577,18 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 + Seq(Row(1)), // select x.intCol Seq(Row(1)), // select intCol + Seq(Row("first")), // select x.stringCol Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select x.doubleCol Seq(Row(1.0)), // select doubleCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 + Seq(Row(2)), // select x.intCol Seq(Row(2)), // select intCol + Seq(Row("second")), // select x.stringCol Seq(Row("second")), // select stringCol - Seq(Row(2.0)), // select doubleCol - Seq.empty[Row] // drop x + Seq(Row(2.0)), // select x.doubleCol + Seq(Row(2.0)) // select doubleCol ) verifySqlScriptResult(sqlScript, expected) } @@ -1602,8 +1605,11 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | (2, MAP('b', MAP(2, 20)), STRUCT('Jane', 30), ARRAY('apple', 'banana')); | FOR row AS SELECT * FROM t ORDER BY int_column DO | SELECT row.map_column; + | SELECT map_column; | SELECT row.struct_column; + | SELECT struct_column; | SELECT row.array_column; + | SELECT array_column; | END FOR; |END |""".stripMargin @@ -1611,18 +1617,18 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // create table Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 + Seq(Row(Map("a" -> Map(1 -> 10)))), // select row.map_column Seq(Row(Map("a" -> Map(1 -> 10)))), // select map_column + Seq(Row(Row("John", 25))), // select row.struct_column Seq(Row(Row("John", 25))), // select struct_column + Seq(Row(Array("apple", "banana"))), // select row.array_column Seq(Row(Array("apple", "banana"))), // select array_column - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 + Seq(Row(Map("b" -> Map(2 -> 20)))), // select row.map_column Seq(Row(Map("b" -> Map(2 -> 20)))), // select map_column + Seq(Row(Row("Jane", 30))), // select row.struct_column Seq(Row(Row("Jane", 30))), // select struct_column - Seq(Row(Array("apple", "banana"))), // select array_column - Seq.empty[Row] // drop x + Seq(Row(Array("apple", "banana"))), // select row.array_column + Seq(Row(Array("apple", "banana"))) // select array_column ) verifySqlScriptResult(sqlScript, expected) } @@ -1712,21 +1718,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // insert Seq.empty[Row], // insert Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // 
set x to row 0 Seq(Row("first")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 - // Seq.empty[Row], // drop x - TODO: uncomment when iterate can handle dropping vars - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 2 Seq(Row("third")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 3 - Seq(Row("fourth")), // select stringCol - Seq.empty[Row], // drop x + Seq(Row("fourth")) // select stringCol ) verifySqlScriptResult(sqlScript, expected) } @@ -1758,17 +1752,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // insert Seq.empty[Row], // insert Seq.empty[Row], // insert - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 Seq(Row("first")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 - Seq(Row("second")), // select stringCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 2 - // Seq.empty[Row], // drop x + Seq(Row("second")) // select stringCol ) verifySqlScriptResult(sqlScript, expected) } @@ -1797,22 +1782,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq.empty[Row], // set i - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 Seq(Row(0)), // select intCol - Seq.empty[Row], // drop x Seq.empty[Row], // insert Seq.empty[Row], // set i - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 0 Seq(Row(0)), // select intCol - Seq.empty[Row], // drop x - Seq.empty[Row], // declare x - Seq.empty[Row], // set x to row 1 Seq(Row(1)), // select intCol - Seq.empty[Row], // drop x Seq.empty[Row], // insert - Seq.empty[Row], // drop i + Seq.empty[Row] // drop i ) verifySqlScriptResult(sqlScript, expected) } From 78eb903ecee5b48fb0424974573ed4028b7c5f12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 12 Nov 2024 18:33:33 +0100 Subject: [PATCH 22/39] update execution node test --- .../SqlScriptingExecutionNodeSuite.scala | 37 ------------------- .../SqlScriptingInterpreterSuite.scala | 4 +- 2 files changed, 3 insertions(+), 38 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index accab422ccf56..d750d6a6c8465 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -118,7 +118,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label case forStmt: ForStatementExec => forStmt.label.get - case _: SingleStatementExec => "SingleStatementExec" case _ => fail("Unexpected statement type") } @@ -718,8 +717,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare var - "SingleStatementExec", // set var "body" )) } @@ -738,12 +735,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator 
val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare var - "SingleStatementExec", // set var "statement1", "statement2", - "SingleStatementExec", // declare var - "SingleStatementExec", // set var "statement1", "statement2" )) @@ -783,21 +776,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare x - "SingleStatementExec", // set x - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body", - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body", - "SingleStatementExec", // declare x - "SingleStatementExec", // set x - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body", - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body" )) } @@ -883,12 +864,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare var - "SingleStatementExec", // set var "statement1", "lbl1", - "SingleStatementExec", // declare var - "SingleStatementExec", // set var "statement1", "lbl1" )) @@ -909,8 +886,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare var - "SingleStatementExec", // set var "statement1", "lbl1" )) @@ -939,16 +914,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare x - "SingleStatementExec", // set x - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body1", "lbl1", - "SingleStatementExec", // declare x - "SingleStatementExec", // set x - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body1", "lbl1" )) @@ -977,10 +944,6 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "SingleStatementExec", // declare x - "SingleStatementExec", // set x - "SingleStatementExec", // declare y - "SingleStatementExec", // set y "body1", "lbl1" )) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 9f47027c57fa1..8890966e25d65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1803,7 +1803,9 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | INSERT INTO t VALUES (1, 'first', 1.0); | INSERT INTO t VALUES (2, 'second', 2.0); | FOR SELECT * FROM t ORDER BY intCol DO - | SELECT 1; + | SELECT intCol; + | SELECT stringCol; + | SELECT doubleCol; | END FOR; |END |""".stripMargin From f498a50e76df05e38be9aacb54e34ab9f5d668e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 13 Nov 2024 12:54:44 +0100 Subject: [PATCH 23/39] add drop variables --- 
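[Patch note, kept below the ---.]
ForStatementExec now cleans up the session variables it creates: one per result column,
plus one struct variable when the loop declares a row variable. They are declared before
the first row, re-set for every row, and with this change also dropped once the last
row's body finishes or the loop is left via LEAVE. Roughly, for a one-column query with
loop variable x, the interpreted sequence is (a sketch only; the names come from the test
fixtures, and the exact ordering is what the suites below assert):

    DECLARE intCol;                      -- once, before the first row
    DECLARE x;
    SET VAR intCol = ...;                -- repeated for every row
    SET VAR x = named_struct('intCol', intCol);
    ... body statements ...
    DROP TEMPORARY VARIABLE intCol;      -- after the last row, or on LEAVE
    DROP TEMPORARY VARIABLE x;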
.../scripting/SqlScriptingExecutionNode.scala | 24 ++++++++-- .../SqlScriptingExecutionNodeSuite.scala | 48 +++++++++---------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index aebfa29f5cdbd..4ecc4ec369e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable} +import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType @@ -710,7 +710,7 @@ class ForStatementExec( } variablesMap.keys.toSeq .map(colName => createSetVarExec(colName, variablesMap(colName))) - .foreach(exec => exec.buildDataFrame(session).collect()) + .foreach(setVarExec => setVarExec.buildDataFrame(session).collect()) state = ForState.Body body.reset() next() @@ -725,7 +725,7 @@ class ForStatementExec( leaveStatementExec.hasBeenMatched = true } interrupted = true - // drop vars + dropVars() return retStmt case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => if (label.contains(iterStatementExec.label)) { @@ -740,6 +740,11 @@ class ForStatementExec( if (!body.getTreeIterator.hasNext) { currRow += 1 state = ForState.VariableAssignment + + // on final iteration, drop variables + if (currRow == cachedQueryResult().length) { + dropVars() + } } retStmt } @@ -782,6 +787,13 @@ class ForStatementExec( variablesMap } + private def dropVars() = { + variablesMap.keys.toSeq + .map(colName => createDropVarExec(colName)) + .foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect()) + areVariablesDeclared = false + } + private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") val declareVariable = CreateVariable( @@ -802,6 +814,11 @@ class ForStatementExec( new SingleStatementExec(setIdentifierToCurrentRow, Origin(), isInternal = true) } + private def createDropVarExec(varName: String): SingleStatementExec = { + val dropVar = DropVariable(UnresolvedIdentifier(Seq(varName)), ifExists = true) + new SingleStatementExec(dropVar, Origin(), isInternal = true) + } + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator override def reset(): Unit = { @@ -809,6 +826,7 @@ class ForStatementExec( isResultCacheValid = false currRow = 0 variablesMap = Map() + areVariablesDeclared = false body.reset() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 
d750d6a6c8465..1588299fe0d9b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -93,14 +93,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() } - case class TestForStatementQuery(numberOfRows: Int, description: String) + case class TestForStatementQuery(numberOfRows: Int, columnName: String, description: String) extends SingleStatementExec( DummyLogicalPlan(), Origin(startIndex = Some(0), stopIndex = Some(description.length)), isInternal = false) { override def buildDataFrame(session: SparkSession): DataFrame = { val data = Seq.range(0, numberOfRows).map(Row(_)) - val schema = List(StructField("intCol", IntegerType)) + val schema = List(StructField(columnName, IntegerType)) spark.createDataFrame( spark.sparkContext.parallelize(data), @@ -708,7 +708,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - enters body once") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(1, "query1"), + query = TestForStatementQuery(1, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for1"), @@ -724,7 +724,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - enters body with multiple statements multiple times") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -745,7 +745,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - empty result") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(0, "query1"), + query = TestForStatementQuery(0, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), label = Some("for1"), @@ -759,11 +759,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - nested") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query2"), + query = TestForStatementQuery(2, "intCol1", "query2"), variableName = Some("y"), body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for2"), @@ -786,7 +786,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - enters body once") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(1, "query1"), + query = TestForStatementQuery(1, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for1"), @@ -800,7 +800,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - enters body with multiple statements multiple times") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = 
TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -816,7 +816,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - empty result") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(0, "query1"), + query = TestForStatementQuery(0, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), label = Some("for1"), @@ -830,11 +830,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - nested") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query2"), + query = TestForStatementQuery(2, "intCol1", "query2"), variableName = None, body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for2"), @@ -852,7 +852,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - iterate") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -874,7 +874,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - leave") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -894,11 +894,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - nested - iterate outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query2"), + query = TestForStatementQuery(2, "intCol1", "query2"), variableName = Some("y"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), @@ -924,11 +924,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - nested - leave outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query2"), + query = TestForStatementQuery(2, "intCol", "query2"), variableName = Some("y"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), @@ -952,7 +952,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - iterate") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -969,7 +969,7 @@ 
class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - leave") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -986,11 +986,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - nested - iterate outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query2"), + query = TestForStatementQuery(2, "intCol1", "query2"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), @@ -1011,11 +1011,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - nested - leave outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query1"), + query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "query2"), + query = TestForStatementQuery(2, "intCol1", "query2"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), From 755ebe4499a554044efa953d8487afa48b3f3de7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 14 Nov 2024 17:12:13 +0100 Subject: [PATCH 24/39] fix for nested arrays, and change drop variable logic to work with leave/iterate/normal case --- .../scripting/SqlScriptingExecutionNode.scala | 73 +++-- .../SqlScriptingExecutionNodeSuite.scala | 57 +++- .../SqlScriptingInterpreterSuite.scala | 299 +++++++++++------- 3 files changed, 281 insertions(+), 148 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 4ecc4ec369e5b..aa558b2e9e0dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateMap, CreateNamedStruct, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, DropVariable, LogicalPlan, OneRowRelation, Project, SetVariable} import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors @@ -654,8 +654,7 @@ class LoopStatementExec( * Executable node for ForStatement. * @param query Executable node for the query. * @param variableName Name of variable used for accessing current row during iteration. - * @param body Executable node for the body. 
If variableName is not None, will have DropVariable - * as the last statement. + * @param body Executable node for the body. * @param label Label set to ForStatement by user or None otherwise. * @param session Spark session that SQL script is executed within. */ @@ -667,7 +666,7 @@ class ForStatementExec( session: SparkSession) extends NonLeafStatementExec { private object ForState extends Enumeration { - val VariableAssignment, Body = Value + val VariableAssignment, Body, VariableCleanup = Value } private var state = ForState.VariableAssignment private var currRow = 0 @@ -677,6 +676,9 @@ class ForStatementExec( // (variableName -> variableExpression) private var variablesMap: Map[String, Expression] = Map() + // compound body used for dropping variables + private var dropVariablesExec: CompoundBodyExec = null + private var queryResult: Array[Row] = null private var isResultCacheValid = false private def cachedQueryResult(): Array[Row] = { @@ -694,8 +696,12 @@ class ForStatementExec( private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { - override def hasNext: Boolean = - !interrupted && cachedQueryResult().length > 0 && currRow < cachedQueryResult().length + override def hasNext: Boolean = { + val resultSize = cachedQueryResult().length + val ret = state == ForState.VariableCleanup || + (!interrupted && resultSize > 0 && currRow < resultSize) + ret + } override def next(): CompoundStatementExec = state match { @@ -703,6 +709,7 @@ class ForStatementExec( variablesMap = createVariablesMapFromRow(currRow) if (!areVariablesDeclared) { + // create and execute declare var statements variablesMap.keys.toSeq .map(colName => createDeclareVarExec(colName, variablesMap(colName))) .foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect()) @@ -725,33 +732,44 @@ class ForStatementExec( leaveStatementExec.hasBeenMatched = true } interrupted = true + // If this for statement encounters LEAVE, it will either not be executed ever + // again, or it will be reset before being executed. + // In either case, variables will not + // be dropped normally, from ForState.VariableCleanup, so we drop them here. dropVars() return retStmt case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => if (label.contains(iterStatementExec.label)) { iterStatementExec.hasBeenMatched = true + } else { + // if an outer loop is being iterated, this for statement will either not be + // executed ever again, or it will be reset before being executed. + // In either case, variables will not + // be dropped normally, from ForState.VariableCleanup, so we drop them here. + dropVars() } - currRow += 1 - state = ForState.VariableAssignment + switchStateFromBody() return retStmt case _ => } if (!body.getTreeIterator.hasNext) { - currRow += 1 - state = ForState.VariableAssignment - - // on final iteration, drop variables - if (currRow == cachedQueryResult().length) { - dropVars() - } + switchStateFromBody() } retStmt + + case ForState.VariableCleanup => + val ret = dropVariablesExec.getTreeIterator.next() + if (!dropVariablesExec.getTreeIterator.hasNext) { + state = ForState.VariableAssignment + } + ret } } /** - * Creates a Catalyst expression from Scala value. + * Creates a Catalyst expression from Scala value.
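+ * For example (illustrative, mirroring the cases below): Map("a" -> 1) becomes
+ * CreateMap(Seq(Literal("a"), Literal(1))), a Row with schema (name, age) becomes
+ * CreateNamedStruct(Seq(Literal("name"), <converted value>, Literal("age"), <converted value>)),
+ * a Seq becomes CreateArray over its converted elements, and anything else falls
+ * through to Literal(value).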
+ * See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings */ private def createExpressionFromValue(value: Any): Expression = value match { case m: Map[_, _] => @@ -759,15 +777,21 @@ class ForStatementExec( val mapArgs = m.keys.toSeq.flatMap { key => Seq(createExpressionFromValue(key), createExpressionFromValue(m(key))) } - CreateMap(mapArgs, false) + CreateMap(mapArgs, useStringTypeWhenEmpty = false) + + // structs match this case case s: Row => - // struct types match this case // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...) val namedStructArgs = s.schema.names.toSeq.flatMap { colName => val valueExpression = createExpressionFromValue(s.getAs(colName)) Seq(Literal(colName), valueExpression) } CreateNamedStruct(namedStructArgs) + + // arrays match this case + case a: collection.Seq[_] => + val arrayArgs = a.toSeq.map(createExpressionFromValue(_)) + CreateArray(arrayArgs, useStringTypeWhenEmpty = false) case _ => Literal(value) } @@ -787,13 +811,24 @@ class ForStatementExec( variablesMap } - private def dropVars() = { + private def dropVars(): Unit = { variablesMap.keys.toSeq .map(colName => createDropVarExec(colName)) .foreach(dropVarExec => dropVarExec.buildDataFrame(session).collect()) areVariablesDeclared = false } + private def switchStateFromBody(): Unit = { + currRow += 1 + state = if (currRow < cachedQueryResult().length) ForState.VariableAssignment + else { + dropVariablesExec = new CompoundBodyExec( + variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)) + ) + ForState.VariableCleanup + } + } + private def createDeclareVarExec(varName: String, variable: Expression): SingleStatementExec = { val defaultExpression = DefaultValueExpression(Literal(null, variable.dataType), "null") val declareVariable = CreateVariable( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 1588299fe0d9b..235dbfbbfd93c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -118,6 +118,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label case forStmt: ForStatementExec => forStmt.label.get + case _: SingleStatementExec => "SingleStatementExec" case _ => fail("Unexpected statement type") } @@ -717,7 +718,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( - "body" + "body", + "SingleStatementExec", // drop local var + "SingleStatementExec" // drop local var )) } @@ -738,7 +741,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "statement1", "statement2", "statement1", - "statement2" + "statement2", + "SingleStatementExec", // drop local var + "SingleStatementExec", // drop local var )) } @@ -778,8 +783,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi assert(statements === Seq( "body", "body", + "SingleStatementExec", // drop inner local var + "SingleStatementExec", // drop inner local var + "body", "body", - "body" + "SingleStatementExec", // drop inner local var + 
"SingleStatementExec", // drop inner local var + "SingleStatementExec", // drop outer local var + "SingleStatementExec", // drop outer local var )) } @@ -794,7 +805,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq("body")) + assert(statements === Seq( + "body", + "SingleStatementExec", // drop local var + )) } test("for statement no variable - enters body with multiple statements multiple times") { @@ -810,7 +824,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq("statement1", "statement2", "statement1", "statement2")) + assert(statements === Seq( + "statement1", "statement2", "statement1", "statement2", + "SingleStatementExec", // drop local var + )) } test("for statement no variable - empty result") { @@ -846,7 +863,13 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq("body", "body", "body", "body")) + assert(statements === Seq( + "body", "body", + "SingleStatementExec", // drop inner local var + "body", "body", + "SingleStatementExec", // drop inner local var + "SingleStatementExec", // drop outer local var + )) } test("for statement - iterate") { @@ -867,7 +890,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "statement1", "lbl1", "statement1", - "lbl1" + "lbl1", + "SingleStatementExec", // drop local var + "SingleStatementExec", // drop local var )) } @@ -897,6 +922,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi query = TestForStatementQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( + TestLeafStatement("outer_body"), new ForStatementExec( query = TestForStatementQuery(2, "intCol1", "query2"), variableName = Some("y"), @@ -914,10 +940,14 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( + "outer_body", "body1", "lbl1", + "outer_body", "body1", - "lbl1" + "lbl1", + "SingleStatementExec", // drop local var + "SingleStatementExec", // drop local var )) } @@ -963,7 +993,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq("statement1", "lbl1", "statement1", "lbl1")) + assert(statements === Seq( + "statement1", "lbl1", "statement1", "lbl1", + "SingleStatementExec", // drop local var + )) } test("for statement no variable - leave") { @@ -989,6 +1022,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi query = TestForStatementQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( + TestLeafStatement("outer_body"), new ForStatementExec( query = TestForStatementQuery(2, "intCol1", "query2"), variableName = None, @@ -1005,7 +1039,10 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq("body1", "lbl1", "body1", "lbl1")) + assert(statements === Seq( + "outer_body", "body1", "lbl1", "outer_body", "body1", 
"lbl1", + "SingleStatementExec", // drop local var + )) } test("for statement no variable - nested - leave outer loop") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8890966e25d65..8b6a8f4fd83ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1548,13 +1548,34 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScriptText, expected) } - test("testetst") { - val sqlScript = "DECLARE my_map DEFAULT MAP(1,1);" - verifySqlScriptResult(sqlScript, Seq.empty[Seq[Row]]) + // todo: duplicate for non var tests, negative tests + test("for statement - enters body once") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | FOR row AS SELECT * FROM t DO + | SELECT row.intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(1)), // select row.intCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } } - // todo: complex types in for, better var names in tests - test("for test") { + test("for statement - enters body with multiple statements multiple times") { withTable("t") { val sqlScript = """ @@ -1562,12 +1583,12 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; | INSERT INTO t VALUES (1, 'first', 1.0); | INSERT INTO t VALUES (2, 'second', 2.0); - | FOR x AS SELECT * FROM t ORDER BY intCol DO - | SELECT x.intCol; + | FOR row AS SELECT * FROM t ORDER BY intCol DO + | SELECT row.intCol; | SELECT intCol; - | SELECT x.stringCol; + | SELECT row.stringCol; | SELECT stringCol; - | SELECT x.doubleCol; + | SELECT row.doubleCol; | SELECT doubleCol; | END FOR; |END @@ -1577,32 +1598,36 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq.empty[Row], // insert - Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select row.intCol Seq(Row(1)), // select intCol - Seq(Row("first")), // select x.stringCol + Seq(Row("first")), // select row.stringCol Seq(Row("first")), // select stringCol - Seq(Row(1.0)), // select x.doubleCol + Seq(Row(1.0)), // select row.doubleCol Seq(Row(1.0)), // select doubleCol - Seq(Row(2)), // select x.intCol + Seq(Row(2)), // select row.intCol Seq(Row(2)), // select intCol - Seq(Row("second")), // select x.stringCol + Seq(Row("second")), // select row.stringCol Seq(Row("second")), // select stringCol - Seq(Row(2.0)), // select x.doubleCol - Seq(Row(2.0)) // select doubleCol + Seq(Row(2.0)), // select row.doubleCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var ) verifySqlScriptResult(sqlScript, expected) } } - test("for test complex types") { + test("for test - map, struct, array") { withTable("t") { val sqlScript = """ |BEGIN - | CREATE TABLE t 
(int_column INT, map_column MAP<STRING, MAP<INT, INT>>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
+          | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
           | INSERT INTO t VALUES
-          |  (1, MAP('a', MAP(1, 10)), STRUCT('John', 25), ARRAY('apple', 'banana')),
-          |  (2, MAP('b', MAP(2, 20)), STRUCT('Jane', 30), ARRAY('apple', 'banana'));
+          |  (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')),
+          |  (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear'));
           | FOR row AS SELECT * FROM t ORDER BY int_column DO
           |   SELECT row.map_column;
           |   SELECT map_column;
@@ -1617,61 +1642,122 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       val expected = Seq(
         Seq.empty[Row], // create table
         Seq.empty[Row], // insert
-        Seq(Row(Map("a" -> Map(1 -> 10)))), // select row.map_column
-        Seq(Row(Map("a" -> Map(1 -> 10)))), // select map_column
+        Seq(Row(Map("a" -> 1))), // select row.map_column
+        Seq(Row(Map("a" -> 1))), // select map_column
         Seq(Row(Row("John", 25))), // select row.struct_column
         Seq(Row(Row("John", 25))), // select struct_column
-        Seq(Row(Array("apple", "banana"))), // select row.array_column
-        Seq(Row(Array("apple", "banana"))), // select array_column
-        Seq(Row(Map("b" -> Map(2 -> 20)))), // select row.map_column
-        Seq(Row(Map("b" -> Map(2 -> 20)))), // select map_column
+        Seq(Row(Array("apricot", "quince"))), // select row.array_column
+        Seq(Row(Array("apricot", "quince"))), // select array_column
+        Seq(Row(Map("b" -> 2))), // select row.map_column
+        Seq(Row(Map("b" -> 2))), // select map_column
         Seq(Row(Row("Jane", 30))), // select row.struct_column
         Seq(Row(Row("Jane", 30))), // select struct_column
-        Seq(Row(Array("apple", "banana"))), // select row.array_column
-        Seq(Row(Array("apple", "banana"))) // select array_column
+        Seq(Row(Array("plum", "pear"))), // select row.array_column
+        Seq(Row(Array("plum", "pear"))), // select array_column
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
+
+  test("for test - nested struct") {
+    withTable("t") {
+      val sqlScript =
+        """
+          |BEGIN
+          | CREATE TABLE t
+          | (int_column INT, struct_column STRUCT<a: INT, b: STRUCT<c: STRUCT<d: STRING>>>);
+          | INSERT INTO t VALUES
+          |  (1, STRUCT(1, STRUCT(STRUCT("one")))),
+          |  (2, STRUCT(2, STRUCT(STRUCT("two"))));
+          | FOR row AS SELECT * FROM t ORDER BY int_column DO
+          |   SELECT row.struct_column;
+          |   SELECT struct_column;
+          | END FOR;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(
+        Seq.empty[Row], // create table
+        Seq.empty[Row], // insert
+        Seq(Row(Row(1, Row(Row("one"))))), // select row.struct_column
+        Seq(Row(Row(1, Row(Row("one"))))), // select struct_column
+        Seq(Row(Row(2, Row(Row("two"))))), // select row.struct_column
+        Seq(Row(Row(2, Row(Row("two"))))), // select struct_column
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
       )
       verifySqlScriptResult(sqlScript, expected)
     }
   }
 
-  // test("for test complex types") {
-  //    withTable("t") {
-  //      val sqlScript =
-  //        """
-  //          |BEGIN
-  //          | CREATE TABLE t (int_column INT, map_column MAP<STRING, INT>, struct_column STRUCT<name: STRING, age: INT>, array_column ARRAY<STRING>);
-  //          | INSERT INTO t VALUES
-  //          |  (1, MAP('a', 1, 'b', 2), STRUCT('John', 25), ARRAY('apple', 'banana')),
-  //          |  (2, MAP('c', 3, 'd', 4), STRUCT('Jane', 30), ARRAY('orange', 'grape')),
-  //          |  (3, MAP('e', 5, 'f', 6), STRUCT('Bob', 35), ARRAY('pear', 'peach'));
-  //          | FOR row AS 
SELECT * FROM t ORDER BY int_column DO
-  //          |   SELECT row.map_column;
-  //          |   SELECT row.struct_column;
-  //          |   SELECT row.array_column;
-  //          | END FOR;
-  //          |END
-  //          |""".stripMargin
-  //
-  //      val expected = Seq(
-  //        Seq.empty[Row], // create table
-  //        Seq.empty[Row], // insert
-  //        Seq.empty[Row], // insert
-  //        Seq.empty[Row], // declare x
-  //        Seq.empty[Row], // set x to row 0
-  //        Seq(Row(1)), // select intCol
-  //        Seq(Row("first")), // select stringCol
-  //        Seq(Row(1.0)), // select doubleCol
-  //        Seq.empty[Row], // drop x
-  //        Seq.empty[Row], // declare x
-  //        Seq.empty[Row], // set x to row 1
-  //        Seq(Row(2)), // select intCol
-  //        Seq(Row("second")), // select stringCol
-  //        Seq(Row(2.0)), // select doubleCol
-  //        Seq.empty[Row] // drop x
-  //      )
-  //      verifySqlScriptResult(sqlScript, expected)
-  //    }
-  //  }
+  test("for test - nested map") {
+    withTable("t") {
+      val sqlScript =
+        """
+          |BEGIN
+          | CREATE TABLE t (int_column INT, map_column MAP<STRING, MAP<INT, MAP<BOOLEAN, INT>>>);
+          | INSERT INTO t VALUES
+          |  (1, MAP('a', MAP(1, MAP(false, 10)))),
+          |  (2, MAP('b', MAP(2, MAP(true, 20))));
+          | FOR row AS SELECT * FROM t ORDER BY int_column DO
+          |   SELECT row.map_column;
+          |   SELECT map_column;
+          | END FOR;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(
+        Seq.empty[Row], // create table
+        Seq.empty[Row], // insert
+        Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select row.map_column
+        Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column
+        Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select row.map_column
+        Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
+
+  test("for test - nested array") {
+    withTable("t") {
+      val sqlScript =
+        """
+          |BEGIN
+          | CREATE TABLE t
+          | (int_column INT, array_column ARRAY<ARRAY<ARRAY<INT>>>);
+          | INSERT INTO t VALUES
+          |  (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))),
+          |  (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12))));
+          | FOR row AS SELECT * FROM t ORDER BY int_column DO
+          |   SELECT row.array_column;
+          |   SELECT array_column;
+          | END FOR;
+          |END
+          |""".stripMargin
+
+      val expected = Seq(
+        Seq.empty[Row], // create table
+        Seq.empty[Row], // insert
+        Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // row.array_column
+        Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column
+        Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // row.array_column
+        Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+        Seq.empty[Row], // drop local var
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
 
   test("for test empty result") {
     withTable("t") {
@@ -1679,8 +1765,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       """
       |BEGIN
       | CREATE TABLE t (intCol INT) using parquet;
-      | FOR x AS SELECT * FROM t ORDER BY intCol DO
-      |   SELECT x.intCol;
+      | FOR row AS SELECT * FROM t ORDER BY intCol DO
+      |   SELECT row.intCol;
       | END FOR;
      |END
      |""".stripMargin
@@ -1698,15 +1784,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
       """
       |BEGIN
       | CREATE TABLE t (intCol INT, stringCol STRING) using parquet;
-      | INSERT INTO t VALUES (1, 'first');
-      | INSERT INTO t VALUES (2, 'second');
-      | INSERT INTO t VALUES (3, 'third');
-      | INSERT INTO t VALUES (4, 'fourth');
+      | INSERT INTO t VALUES (1, 
'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); | | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO | IF x.intCol = 2 THEN | ITERATE lbl; | END IF; + | SELECT stringCol; | SELECT x.stringCol; | END FOR; |END @@ -1715,12 +1799,15 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // create table Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert Seq(Row("first")), // select stringCol + Seq(Row("first")), // select x.stringCol Seq(Row("third")), // select stringCol - Seq(Row("fourth")) // select stringCol + Seq(Row("third")), // select x.stringCol + Seq(Row("fourth")), // select stringCol + Seq(Row("fourth")), // select x.stringCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1732,15 +1819,13 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { """ |BEGIN | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; - | INSERT INTO t VALUES (1, 'first'); - | INSERT INTO t VALUES (2, 'second'); - | INSERT INTO t VALUES (3, 'third'); - | INSERT INTO t VALUES (4, 'fourth'); + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); | | lbl: FOR x AS SELECT * FROM t ORDER BY intCol DO | IF x.intCol = 3 THEN | LEAVE lbl; | END IF; + | SELECT stringCol; | SELECT x.stringCol; | END FOR; |END @@ -1749,11 +1834,10 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val expected = Seq( Seq.empty[Row], // create table Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert - Seq.empty[Row], // insert Seq(Row("first")), // select stringCol - Seq(Row("second")) // select stringCol + Seq(Row("first")), // select x.stringCol + Seq(Row("second")), // select stringCol + Seq(Row("second")) // select x.stringCol ) verifySqlScriptResult(sqlScript, expected) } @@ -1764,58 +1848,35 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN - | DECLARE i = 0; + | DECLARE cnt = 0; | CREATE TABLE t (intCol INT) using parquet; | INSERT INTO t VALUES (0); - | WHILE i < 2 DO - | SET i = i + 1; + | WHILE cnt < 2 DO + | SET cnt = cnt + 1; | FOR x AS SELECT * FROM t ORDER BY intCol DO | SELECT x.intCol; | END FOR; - | INSERT INTO t VALUES (i); + | INSERT INTO t VALUES (cnt); | END WHILE; |END |""".stripMargin val expected = Seq( - Seq.empty[Row], // declare i + Seq.empty[Row], // declare cnt Seq.empty[Row], // create table Seq.empty[Row], // insert - Seq.empty[Row], // set i + Seq.empty[Row], // set cnt Seq(Row(0)), // select intCol Seq.empty[Row], // insert - Seq.empty[Row], // set i + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt Seq(Row(0)), // select intCol Seq(Row(1)), // select intCol Seq.empty[Row], // insert - Seq.empty[Row] // drop i - ) - verifySqlScriptResult(sqlScript, expected) - } - } - - test("for test no variable") { - withTable("t") { - val sqlScript = - """ - |BEGIN - | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; - | INSERT INTO t VALUES (1, 'first', 1.0); - | INSERT INTO t VALUES (2, 'second', 2.0); - | FOR SELECT * FROM t ORDER BY intCol DO - | SELECT intCol; - | SELECT stringCol; - | SELECT doubleCol; - | END FOR; - |END - |""".stripMargin - - val expected = Seq( - Seq.empty[Row], // create table - Seq.empty[Row], 
// insert - Seq.empty[Row], // insert - Seq(Row(1)), // select 1 - Seq(Row(1)), // select 1 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt ) verifySqlScriptResult(sqlScript, expected) } From 49017e4070d61682d9d312811870ed2f0fa41869 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 18 Nov 2024 18:39:01 +0100 Subject: [PATCH 25/39] add nested tests --- .../scripting/SqlScriptingExecutionNode.scala | 10 +- .../SqlScriptingInterpreterSuite.scala | 212 +++++++++++++++++- 2 files changed, 209 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index aa558b2e9e0dd..5e2b0a430d40e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -698,9 +698,8 @@ class ForStatementExec( new Iterator[CompoundStatementExec] { override def hasNext: Boolean = { val resultSize = cachedQueryResult().length - val ret = state == ForState.VariableCleanup || + state == ForState.VariableCleanup || (!interrupted && resultSize > 0 && currRow < resultSize) - ret } override def next(): CompoundStatementExec = state match { @@ -732,7 +731,7 @@ class ForStatementExec( leaveStatementExec.hasBeenMatched = true } interrupted = true - // If this for statement encounters LEAVE, it will either not be executed ever + // If this for statement encounters LEAVE, it will either not be executed // again, or it will be reset before being executed. // In either case, variables will not // be dropped normally, from ForState.VariableCleanup, so we drop them here. @@ -743,7 +742,7 @@ class ForStatementExec( iterStatementExec.hasBeenMatched = true } else { // if an outer loop is being iterated, this for statement will either not be - // executed ever again, or it will be reset before being executed. + // executed again, or it will be reset before being executed. // In either case, variables will not // be dropped normally, from ForState.VariableCleanup, so we drop them here. 
dropVars() @@ -761,6 +760,7 @@ class ForStatementExec( case ForState.VariableCleanup => val ret = dropVariablesExec.getTreeIterator.next() if (!dropVariablesExec.getTreeIterator.hasNext) { + // stops execution, as at this point currRow == resultSize state = ForState.VariableAssignment } ret @@ -862,6 +862,8 @@ class ForStatementExec( currRow = 0 variablesMap = Map() areVariablesDeclared = false + dropVariablesExec = null + interrupted = false body.reset() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 8b6a8f4fd83ea..9b0bd6a0c57ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1548,7 +1548,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScriptText, expected) } - // todo: duplicate for non var tests, negative tests + // todo: duplicate for non var tests, negative tests, + // for statement nested, nested leave, nested iterate test("for statement - enters body once") { withTable("t") { val sqlScript = @@ -1619,7 +1620,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test - map, struct, array") { + test("for statement - map, struct, array") { withTable("t") { val sqlScript = """ @@ -1664,7 +1665,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test - nested struct") { + test("for statement - nested struct") { withTable("t") { val sqlScript = """ @@ -1696,7 +1697,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test - nested map") { + test("for statement - nested map") { withTable("t") { val sqlScript = """ @@ -1727,7 +1728,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test - nested array") { + test("for statement - nested array") { withTable("t") { val sqlScript = """ @@ -1759,7 +1760,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test empty result") { + test("for statement empty result") { withTable("t") { val sqlScript = """ @@ -1778,7 +1779,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test iterate") { + test("for statement iterate") { withTable("t") { val sqlScript = """ @@ -1813,7 +1814,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test leave") { + test("for statement leave") { withTable("t") { val sqlScript = """ @@ -1843,7 +1844,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } - test("for test nested in while") { + test("for statement - nested - in while") { withTable("t") { val sqlScript = """ @@ -1881,4 +1882,197 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } } + + test("for statement - nested - in other for") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | FOR x as SELECT * FROM t ORDER BY intCol DO + | FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | 
SELECT x.intCol; + | SELECT intCol; + | SELECT y.intCol2; + | SELECT intCol2; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(0)), // select x.intCol + Seq(Row(0)), // select intCol + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(0)), // select x.intCol + Seq(Row(0)), // select intCol + Seq(Row(2)), // select y.intCol2 + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select intCol + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(1)), // select x.intCol + Seq(Row(1)), // select intCol + Seq(Row(2)), // select y.intCol2 + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // ignored until loops are fixed to support empty bodies + ignore("for statement - nested - empty result set") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | REPEAT + | FOR x AS SELECT * FROM t ORDER BY intCol DO + | SELECT x.intCol; + | END FOR; + | UNTIL 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - iterate outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | ITERATE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - leave outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | LEAVE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create 
table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - nested - leave inner loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR x as SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR y AS SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT y.intCol2; + | SELECT intCol2; + | LEAVE lbl2; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select y.intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row], // drop outer var + Seq.empty[Row], // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } } From 9a2f5fafc22f10f20baf706c8586f447cdb1ac73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 19 Nov 2024 13:07:53 +0100 Subject: [PATCH 26/39] add tests for no variables variant of for --- .../SqlScriptingInterpreterSuite.scala | 464 +++++++++++++++++- 1 file changed, 462 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 9b0bd6a0c57ea..30a270294eaa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1548,8 +1548,6 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScriptText, expected) } - // todo: duplicate for non var tests, negative tests, - // for statement nested, nested leave, nested iterate test("for statement - enters body once") { withTable("t") { val sqlScript = @@ -2075,4 +2073,466 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(sqlScript, expected) } } + + test("for statement - no variable - enters body once") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | FOR SELECT * FROM t DO + | SELECT intCol; + | SELECT stringCol; + | SELECT doubleCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - enters body with multiple statements multiple times") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING, doubleCol DOUBLE) using parquet; + | INSERT INTO t VALUES (1, 'first', 1.0); + | INSERT INTO t VALUES (2, 'second', 2.0); + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | SELECT stringCol; + | SELECT doubleCol; + | END FOR; + |END + 
|""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(1)), // select intCol + Seq(Row("first")), // select stringCol + Seq(Row(1.0)), // select doubleCol + Seq(Row(2)), // select intCol + Seq(Row("second")), // select stringCol + Seq(Row(2.0)), // select doubleCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - map, struct, array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP, struct_column STRUCT, array_column ARRAY); + | INSERT INTO t VALUES + | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), + | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT map_column; + | SELECT struct_column; + | SELECT array_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> 1))), // select map_column + Seq(Row(Row("John", 25))), // select struct_column + Seq(Row(Array("apricot", "quince"))), // select array_column + Seq(Row(Map("b" -> 2))), // select map_column + Seq(Row(Row("Jane", 30))), // select struct_column + Seq(Row(Array("plum", "pear"))), // select array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested struct") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, struct_column STRUCT>>); + | INSERT INTO t VALUES + | (1, STRUCT(1, STRUCT(STRUCT("one")))), + | (2, STRUCT(2, STRUCT(STRUCT("two")))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT struct_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Row(1, Row(Row("one"))))), // select struct_column + Seq(Row(Row(2, Row(Row("two"))))), // select struct_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested map") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (int_column INT, map_column MAP>>); + | INSERT INTO t VALUES + | (1, MAP('a', MAP(1, MAP(false, 10)))), + | (2, MAP('b', MAP(2, MAP(true, 20)))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT map_column; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column + Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested array") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t + | (int_column INT, array_column ARRAY>>); + | INSERT INTO t VALUES + | (1, ARRAY(ARRAY(ARRAY(1, 2), ARRAY(3, 4)), ARRAY(ARRAY(5, 6)))), + | (2, ARRAY(ARRAY(ARRAY(7, 8), ARRAY(9, 10)), ARRAY(ARRAY(11, 12)))); + | FOR SELECT * FROM t ORDER BY int_column DO + | SELECT array_column; + | END 
FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column + Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - empty result") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - iterate") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR SELECT * FROM t ORDER BY intCol DO + | IF intCol = 2 THEN + | ITERATE lbl; + | END IF; + | SELECT stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("third")), // select stringCol + Seq(Row("fourth")), // select stringCol + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - leave") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT, stringCol STRING) using parquet; + | INSERT INTO t VALUES (1, 'first'), (2, 'second'), (3, 'third'), (4, 'fourth'); + | + | lbl: FOR SELECT * FROM t ORDER BY intCol DO + | IF intCol = 3 THEN + | LEAVE lbl; + | END IF; + | SELECT stringCol; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq(Row("first")), // select stringCol + Seq(Row("second")), // select stringCol + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - in while") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE cnt = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (0); + | WHILE cnt < 2 DO + | SET cnt = cnt + 1; + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + | INSERT INTO t VALUES (cnt); + | END WHILE; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - in other for") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | FOR SELECT * FROM t ORDER BY intCol DO + | FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol; + | SELECT intCol2; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = 
Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(0)), // select intCol + Seq(Row(3)), // select intCol2 + Seq(Row(0)), // select intCol + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq(Row(1)), // select intCol + Seq(Row(3)), // select intCol2 + Seq(Row(1)), // select intCol + Seq(Row(2)), // select intCol2 + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + // ignored until loops are fixed to support empty bodies + ignore("for statement - no variable - nested - empty result set") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | REPEAT + | FOR SELECT * FROM t ORDER BY intCol DO + | SELECT intCol; + | END FOR; + | UNTIL 1 = 1 + | END REPEAT; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare cnt + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row], // set cnt + Seq(Row(0)), // select intCol + Seq(Row(1)), // select intCol + Seq.empty[Row], // insert + Seq.empty[Row], // drop local var + Seq.empty[Row] // drop cnt + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - iterate outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | ITERATE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select intCol2 + Seq(Row(3)), // select intCol2 + Seq.empty[Row] // drop outer var + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - leave outer loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | LEAVE lbl1; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // insert + Seq(Row(3)), // select intCol2 + ) + verifySqlScriptResult(sqlScript, expected) + } + } + + test("for statement - no variable - nested - leave inner loop") { + withTable("t", "t2") { + val sqlScript = + """ + |BEGIN + | CREATE TABLE t (intCol INT) using parquet; + | CREATE TABLE t2 (intCol2 INT) using parquet; + | INSERT INTO t VALUES (0), (1); + | INSERT INTO t2 VALUES (2), (3); + | lbl1: FOR SELECT * FROM t ORDER BY intCol DO + | lbl2: FOR SELECT * FROM t2 ORDER BY intCol2 DESC DO + | SELECT intCol2; + | LEAVE lbl2; + | SELECT 1; + | END FOR; + | END FOR; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // create table + Seq.empty[Row], // create 
table
+        Seq.empty[Row], // insert
+        Seq.empty[Row], // insert
+        Seq(Row(3)), // select intCol2
+        Seq(Row(3)), // select intCol2
+        Seq.empty[Row], // drop outer var
+      )
+      verifySqlScriptResult(sqlScript, expected)
+    }
+  }
 }

From cc632c3fc7af2e395d3b87dcdd0731febe814b0c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?=
Date: Tue, 19 Nov 2024 15:00:00 +0100
Subject: [PATCH 27/39] clean up

---
 .../logical/SqlScriptingLogicalPlans.scala    |  8 ++--
 .../parser/SqlScriptingParserSuite.scala      | 14 +++---
 .../scripting/SqlScriptingExecutionNode.scala | 34 +++++++-------
 .../SqlScriptingExecutionNodeSuite.scala      | 46 +++++++++----------
 4 files changed, 52 insertions(+), 50 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index 63e919088ece4..e7eec3c6f1feb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -270,10 +270,10 @@ case class LoopStatement(
 
 /**
  * Logical operator for FOR statement.
- * @param query Query which is executed once, then it's result is iterated on, row by row
- * @param variableName Name of variable which is used to access the current row during iteration
- * @param body Compound body is a collection of statements that are executed once for each row in
- *             the result set of the query
+ * @param query Query which is executed once, then its result set is iterated on, row by row.
+ * @param variableName Name of variable which is used to access the current row during iteration.
+ * @param body Compound body is a collection of statements that are executed for each row in
+ *             the result set of the query.
 * @param label An optional label for the loop which is unique amongst all labels for statements
 *              within which the FOR statement is contained.
 *              If an end label is specified it must match the beginning label. 
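For reference, a minimal sketch of the labeled FOR statement shape this operator models;
the table t, column intCol and label lbl below are illustrative, mirroring the
interpreter tests rather than quoting any one of them:

    lbl: FOR row AS SELECT * FROM t ORDER BY intCol DO
      SELECT row.intCol;
      IF row.intCol = 2 THEN
        LEAVE lbl;
      END IF;
    END FOR lbl;

The query runs exactly once, each row of its result set is exposed through the optional
iteration variable (here row), and the label lets ITERATE and LEAVE target this loop.
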
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index cd7df6255ffe0..c6b80b4e7403f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -1961,7 +1961,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } - test("for statement no label") { + test("for statement - no label") { val sqlScriptText = """ |BEGIN @@ -1987,7 +1987,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.isDefined) } - test("for statement with complex subquery") { + test("for statement - with complex subquery") { val sqlScriptText = """ |BEGIN @@ -2015,7 +2015,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } - test("nested for statement") { + test("for statement - nested") { val sqlScriptText = """ |BEGIN @@ -2052,7 +2052,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { head.asInstanceOf[SingleStatement].getText == "SELECT i + j") } - test("for statement no variable") { + test("for statement - no variable") { val sqlScriptText = """ |BEGIN @@ -2077,7 +2077,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } - test("for statement no label no variable") { + test("for statement - no variable - no label") { val sqlScriptText = """ |BEGIN @@ -2103,7 +2103,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.isDefined) } - test("for statement with complex subquery no variable") { + test("for statement - no variable - with complex subquery") { val sqlScriptText = """ |BEGIN @@ -2131,7 +2131,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(forStmt.label.contains("lbl")) } - test("nested for statement no variable") { + test("for statement - no variable - nested") { val sqlScriptText = """ |BEGIN diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 5e2b0a430d40e..54f2e747588a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -125,8 +125,8 @@ class SingleStatementExec( var isExecuted = false /** - * Builds a DataFrame from the parsedPlan of this SingleStatementExec - * @param session The SparkSession on which the parsedPlan is built + * Builds a DataFrame from the parsedPlan of this SingleStatementExec. + * @param session The SparkSession used. * @return * The DataFrame. 
*/ @@ -676,7 +676,7 @@ class ForStatementExec( // (variableName -> variableExpression) private var variablesMap: Map[String, Expression] = Map() - // compound body used for dropping variables + // compound body used for dropping variables while in ForState.VariableAssignment private var dropVariablesExec: CompoundBodyExec = null private var queryResult: Array[Row] = null @@ -690,22 +690,23 @@ class ForStatementExec( } /** - * Loop can be interrupted by LeaveStatementExec + * For can be interrupted by LeaveStatementExec */ private var interrupted: Boolean = false private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = { val resultSize = cachedQueryResult().length - state == ForState.VariableCleanup || + (state == ForState.VariableCleanup && dropVariablesExec.getTreeIterator.hasNext) || (!interrupted && resultSize > 0 && currRow < resultSize) } override def next(): CompoundStatementExec = state match { case ForState.VariableAssignment => - variablesMap = createVariablesMapFromRow(currRow) + variablesMap = createVariablesMapFromRow(cachedQueryResult()(currRow)) if (!areVariablesDeclared) { // create and execute declare var statements @@ -714,9 +715,12 @@ class ForStatementExec( .foreach(declareVarExec => declareVarExec.buildDataFrame(session).collect()) areVariablesDeclared = true } + + // create and execute set var statements variablesMap.keys.toSeq .map(colName => createSetVarExec(colName, variablesMap(colName))) .foreach(setVarExec => setVarExec.buildDataFrame(session).collect()) + state = ForState.Body body.reset() next() @@ -758,17 +762,12 @@ class ForStatementExec( retStmt case ForState.VariableCleanup => - val ret = dropVariablesExec.getTreeIterator.next() - if (!dropVariablesExec.getTreeIterator.hasNext) { - // stops execution, as at this point currRow == resultSize - state = ForState.VariableAssignment - } - ret + dropVariablesExec.getTreeIterator.next() } } /** - * Creates a Catalyst expression from Scala value.
+   * Recursively creates a Catalyst expression from a Scala value.
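+   * For example, a Scala Map becomes a CreateMap expression, a Row (struct) becomes a
+   * CreateNamedStruct expression, a Seq becomes a CreateArray expression, and any other
+   * value becomes a Literal; nested values are converted recursively.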
* See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings */ private def createExpressionFromValue(value: Any): Expression = value match { @@ -779,7 +778,7 @@ class ForStatementExec( } CreateMap(mapArgs, useStringTypeWhenEmpty = false) - // structs match this case + // structs and rows match this case case s: Row => // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...) val namedStructArgs = s.schema.names.toSeq.flatMap { colName => @@ -795,8 +794,7 @@ class ForStatementExec( case _ => Literal(value) } - private def createVariablesMapFromRow(rowIndex: Int): Map[String, Expression] = { - val row = cachedQueryResult()(rowIndex) + private def createVariablesMapFromRow(row: Row): Map[String, Expression] = { var variablesMap = row.schema.names.toSeq.map { colName => colName -> createExpressionFromValue(row.getAs(colName)) }.toMap @@ -811,6 +809,9 @@ class ForStatementExec( variablesMap } + /** + * Create and immediately execute dropVariable exec nodes for all variables in variablesMap. + */ private def dropVars(): Unit = { variablesMap.keys.toSeq .map(colName => createDropVarExec(colName)) @@ -822,6 +823,7 @@ class ForStatementExec( currRow += 1 state = if (currRow < cachedQueryResult().length) ForState.VariableAssignment else { + // create compound body for dropping nodes after execution is complete dropVariablesExec = new CompoundBodyExec( variablesMap.keys.toSeq.map(colName => createDropVarExec(colName)) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 235dbfbbfd93c..e0518d8e93b83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -93,7 +93,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi statement: LeafStatementExec): Boolean = evaluator.evaluateLoopBooleanCondition() } - case class TestForStatementQuery(numberOfRows: Int, columnName: String, description: String) + case class MockQuery(numberOfRows: Int, columnName: String, description: String) extends SingleStatementExec( DummyLogicalPlan(), Origin(startIndex = Some(0), stopIndex = Some(description.length)), @@ -709,7 +709,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - enters body once") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(1, "intCol", "query1"), + query = MockQuery(1, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for1"), @@ -727,7 +727,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - enters body with multiple statements multiple times") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -750,7 +750,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - empty result") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(0, "intCol", "query1"), + query = MockQuery(0, "intCol", "query1"), 
variableName = Some("x"), body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), label = Some("for1"), @@ -764,11 +764,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - nested") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol1", "query2"), + query = MockQuery(2, "intCol1", "query2"), variableName = Some("y"), body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for2"), @@ -797,7 +797,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - enters body once") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(1, "intCol", "query1"), + query = MockQuery(1, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for1"), @@ -814,7 +814,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - enters body with multiple statements multiple times") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -833,7 +833,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - empty result") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(0, "intCol", "query1"), + query = MockQuery(0, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), label = Some("for1"), @@ -847,11 +847,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - nested") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol1", "query2"), + query = MockQuery(2, "intCol1", "query2"), variableName = None, body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for2"), @@ -875,7 +875,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - iterate") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -899,7 +899,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - leave") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -919,12 +919,12 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - nested - iterate outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = 
TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( TestLeafStatement("outer_body"), new ForStatementExec( - query = TestForStatementQuery(2, "intCol1", "query2"), + query = MockQuery(2, "intCol1", "query2"), variableName = Some("y"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), @@ -954,11 +954,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement - nested - leave outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query2"), + query = MockQuery(2, "intCol", "query2"), variableName = Some("y"), body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), @@ -982,7 +982,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - iterate") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -1002,7 +1002,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - leave") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), @@ -1019,12 +1019,12 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - nested - iterate outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("outer_body"), new ForStatementExec( - query = TestForStatementQuery(2, "intCol1", "query2"), + query = MockQuery(2, "intCol1", "query2"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), @@ -1048,11 +1048,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi test("for statement no variable - nested - leave outer loop") { val iter = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol", "query1"), + query = MockQuery(2, "intCol", "query1"), variableName = None, body = new CompoundBodyExec(Seq( new ForStatementExec( - query = TestForStatementQuery(2, "intCol1", "query2"), + query = MockQuery(2, "intCol1", "query2"), variableName = None, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), From d4447513fb44c23f6631c9f7d442850ebc8320b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 20 Nov 2024 16:25:47 +0100 Subject: [PATCH 28/39] update labels and tests --- .../sql/catalyst/parser/AstBuilder.scala | 11 +++-- .../logical/SqlScriptingLogicalPlans.scala | 17 ++++++- .../parser/SqlScriptingParserSuite.scala | 49 ++++--------------- 3 files changed, 33 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e8a6c4581d595..0eb017459a345 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -226,6 +226,8 @@ class AstBuilder extends DataTypeAstBuilder visitSearchedCaseStatementImpl(searchedCaseContext, labelCtx) case simpleCaseContext: SimpleCaseStatementContext => visitSimpleCaseStatementImpl(simpleCaseContext, labelCtx) + case forStatementContext: ForStatementContext => + visitForStatementImpl(forStatementContext, labelCtx) case stmt => visit(stmt).asInstanceOf[CompoundPlanStatement] } } else { @@ -347,15 +349,18 @@ class AstBuilder extends DataTypeAstBuilder RepeatStatement(condition, body, Some(labelText)) } - override def visitForStatement(ctx: ForStatementContext): ForStatement = { - val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel())) + private def visitForStatementImpl( + ctx: ForStatementContext, + labelCtx: SqlScriptingLabelContext): ForStatement = { + val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) val queryCtx = ctx.query() val query = withOrigin(queryCtx) { SingleStatement(visitQuery(queryCtx)) } val varName = Option(ctx.multipartIdentifier()).map(_.getText) - val body = visitCompoundBody(ctx.compoundBody()) + val body = visitCompoundBodyImpl(ctx.compoundBody(), None, allowVarDeclare = false, labelCtx) + labelCtx.exitLabeledScope(Option(ctx.beginLabel())) ForStatement(query, varName, body, Some(labelText)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index e7eec3c6f1feb..1d3ebe87a6b9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -283,4 +283,19 @@ case class ForStatement( query: SingleStatement, variableName: Option[String], body: CompoundBody, - label: Option[String]) extends CompoundPlanStatement + label: Option[String]) extends CompoundPlanStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq(query, body) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = { + assert(newChildren.length == 2) + ForStatement( + newChildren(0).asInstanceOf[SingleStatement], + variableName, + newChildren(1).asInstanceOf[CompoundBody], + label) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index c6b80b4e7403f..06d95b56c4f86 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery} import org.apache.spark.sql.catalyst.plans.SQLHelper -import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, 
IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SingleStatement, WhileStatement} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.exceptions.SqlScriptingException import org.apache.spark.sql.internal.SQLConf @@ -39,37 +39,6 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { super.afterAll() } - // Tests - test("testtest") { - val sqlScriptText = "DECLARE my_map DEFAULT MAP('x', 0, 'y', 0);" - - val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[ForStatement]) - } - - test("testtesttest") { - val sqlScriptText = "DECLARE my_struct DEFAULT STRUCT<'x', 0, 1.2>;" - - val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[ForStatement]) - } - - test("initial for") { - val sqlScriptText = - """ - |BEGIN - | FOR x AS SELECT 1 DO - | SELECT 1; - | SELECT 2; - | END FOR; - |END""".stripMargin - val tree = parseScript(sqlScriptText) - assert(tree.collection.length == 1) - assert(tree.collection.head.isInstanceOf[ForStatement]) - } - test("single select") { val sqlScriptText = "SELECT 1;" val statement = parsePlan(sqlScriptText) @@ -1944,7 +1913,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 1; | END FOR; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -1969,7 +1938,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 1; | END FOR; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -1996,7 +1965,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT x.c2; | END FOR; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -2025,7 +1994,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | END FOR lbl2; | END FOR lbl1; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -2060,7 +2029,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 1; | END FOR; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -2085,7 +2054,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 1; | END FOR; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -2112,7 +2081,7 @@ class SqlScriptingParserSuite 
extends SparkFunSuite with SQLHelper { | SELECT 2; | END FOR; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) @@ -2141,7 +2110,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | END FOR lbl2; | END FOR lbl1; |END""".stripMargin - val tree = parseScript(sqlScriptText) + val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] assert(tree.collection.length == 1) assert(tree.collection.head.isInstanceOf[ForStatement]) From f78e1fc30aaf8bb943bc5c864cd1f057693a8265 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 20 Nov 2024 16:28:14 +0100 Subject: [PATCH 29/39] nit --- .../spark/sql/catalyst/parser/SqlScriptingParserSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 06d95b56c4f86..57256e5b0d65f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -39,6 +39,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { super.afterAll() } + // Tests test("single select") { val sqlScriptText = "SELECT 1;" val statement = parsePlan(sqlScriptText) From 055a1b233df4d27bf9bc98e6bdba79c883d6bedb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 20 Nov 2024 16:41:13 +0100 Subject: [PATCH 30/39] add unique label tests --- .../parser/SqlScriptingParserSuite.scala | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 57256e5b0d65f..ab647f83b42a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -1822,6 +1822,25 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { parameters = Map("label" -> toSQLId("l_loop"))) } + test("unique label names: nested for loops") { + val sqlScriptText = + """BEGIN + |f_loop: FOR x AS SELECT 1 DO + | f_loop: FOR y AS SELECT 2 DO + | SELECT 1; + | END FOR; + |END FOR; + |END + """.stripMargin + val exception = intercept[SqlScriptingException] { + parsePlan(sqlScriptText).asInstanceOf[CompoundBody] + } + checkError( + exception = exception, + condition = "LABEL_ALREADY_EXISTS", + parameters = Map("label" -> toSQLId("f_loop"))) + } + test("unique label names: begin-end block on the same level") { val sqlScriptText = """BEGIN @@ -1857,10 +1876,13 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | SELECT 4; |UNTIL 1=1 |END REPEAT; + |lbl: FOR x AS SELECT 1 DO + | SELECT 5; + |END FOR; |END """.stripMargin val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody] - assert(tree.collection.length == 4) + assert(tree.collection.length == 5) assert(tree.collection.head.isInstanceOf[CompoundBody]) assert(tree.collection.head.asInstanceOf[CompoundBody].label.get == "lbl") assert(tree.collection(1).isInstanceOf[WhileStatement]) @@ 
-1869,6 +1891,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { assert(tree.collection(2).asInstanceOf[LoopStatement].label.get == "lbl") assert(tree.collection(3).isInstanceOf[RepeatStatement]) assert(tree.collection(3).asInstanceOf[RepeatStatement].label.get == "lbl") + assert(tree.collection(4).isInstanceOf[ForStatement]) + assert(tree.collection(4).asInstanceOf[ForStatement].label.get == "lbl") } test("unique label names: nested labeled scope statements") { @@ -1878,7 +1902,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { | lbl_1: WHILE 1=1 DO | lbl_2: LOOP | lbl_3: REPEAT - | SELECT 4; + | lbl_4: FOR x AS SELECT 1 DO + | SELECT 4; + | END FOR; | UNTIL 1=1 | END REPEAT; | END LOOP; @@ -1904,6 +1930,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { // Repeat statement val repeatStatement = loopStatement.body.collection.head.asInstanceOf[RepeatStatement] assert(repeatStatement.label.get == "lbl_3") + // For statement + val forStatement = repeatStatement.body.collection.head.asInstanceOf[ForStatement] + assert(forStatement.label.get == "lbl_4") } test("for statement") { From 955e79c897c7adced6784578d17381bed2a60880 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 21 Nov 2024 15:34:05 +0100 Subject: [PATCH 31/39] fix formatting and improve tests --- .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../scripting/SqlScriptingExecutionNode.scala | 1 + .../SqlScriptingExecutionNodeSuite.scala | 175 +++++++++--------- 3 files changed, 88 insertions(+), 92 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0eb017459a345..81b50344e2ffb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -350,8 +350,8 @@ class AstBuilder extends DataTypeAstBuilder } private def visitForStatementImpl( - ctx: ForStatementContext, - labelCtx: SqlScriptingLabelContext): ForStatement = { + ctx: ForStatementContext, + labelCtx: SqlScriptingLabelContext): ForStatement = { val labelText = labelCtx.enterLabeledScope(Option(ctx.beginLabel()), Option(ctx.endLabel())) val queryCtx = ctx.query() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 54f2e747588a0..ce3ee18c28b0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -791,6 +791,7 @@ class ForStatementExec( case a: collection.Seq[_] => val arrayArgs = a.toSeq.map(createExpressionFromValue(_)) CreateArray(arrayArgs, useStringTypeWhenEmpty = false) + case _ => Literal(value) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index e0518d8e93b83..c3feaf63aa07e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, OneRowRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{DropVariable, LeafNode, OneRowRelation, Project} import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -118,7 +118,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label case forStmt: ForStatementExec => forStmt.label.get - case _: SingleStatementExec => "SingleStatementExec" + case dropStmt: SingleStatementExec if dropStmt.parsedPlan.isInstanceOf[DropVariable] + => "DropVariable" case _ => fail("Unexpected statement type") } @@ -711,16 +712,16 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(1, "intCol", "query1"), variableName = Some("x"), - body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for1"), - session = spark + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", - "SingleStatementExec", // drop local var - "SingleStatementExec" // drop local var + "DropVariable", // drop for query var intCol + "DropVariable" // drop for loop var x )) } @@ -729,11 +730,11 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), - body = new CompoundBodyExec(Seq( - TestLeafStatement("statement1"), - TestLeafStatement("statement2"))), label = Some("for1"), - session = spark + session = spark, + body = new CompoundBodyExec( + Seq(TestLeafStatement("statement1"), TestLeafStatement("statement2")) + ) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -742,8 +743,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "statement2", "statement1", "statement2", - "SingleStatementExec", // drop local var - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol + "DropVariable", // drop for loop var x )) } @@ -752,9 +753,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(0, "intCol", "query1"), variableName = Some("x"), - body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), label = Some("for1"), - session = spark + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -766,31 +767,31 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), + label = Some("for1"), + session = spark, body = new CompoundBodyExec(Seq( new ForStatementExec( query = MockQuery(2, "intCol1", "query2"), variableName = Some("y"), - body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for2"), - session = spark + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) ) - )), - label = Some("for1"), - session = spark + )) ) 
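      // Asserted below: the inner loop's query variable (intCol1) and loop
      // variable (y) are declared and dropped once per outer iteration, while
      // the outer pair (intCol, x) is dropped only after the whole loop completes.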
)).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", "body", - "SingleStatementExec", // drop inner local var - "SingleStatementExec", // drop inner local var + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for loop var y "body", "body", - "SingleStatementExec", // drop inner local var - "SingleStatementExec", // drop inner local var - "SingleStatementExec", // drop outer local var - "SingleStatementExec", // drop outer local var + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for loop var y + "DropVariable", // drop for query var intCol + "DropVariable", // drop for loop var x )) } @@ -799,15 +800,15 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(1, "intCol", "query1"), variableName = None, - body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for1"), - session = spark + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol )) } @@ -816,17 +817,17 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = None, + label = Some("for1"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), - TestLeafStatement("statement2"))), - label = Some("for1"), - session = spark + TestLeafStatement("statement2"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "statement1", "statement2", "statement1", "statement2", - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol )) } @@ -835,9 +836,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(0, "intCol", "query1"), variableName = None, - body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))), label = Some("for1"), - session = spark + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body1"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -849,26 +850,26 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = None, + label = Some("for1"), + session = spark, body = new CompoundBodyExec(Seq( new ForStatementExec( query = MockQuery(2, "intCol1", "query2"), variableName = None, - body = new CompoundBodyExec(Seq(TestLeafStatement("body"))), label = Some("for2"), - session = spark + session = spark, + body = new CompoundBodyExec(Seq(TestLeafStatement("body"))) ) - )), - label = Some("for1"), - session = spark + )) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", "body", - "SingleStatementExec", // drop inner local var + "DropVariable", // drop for query var intCol1 "body", "body", - "SingleStatementExec", // drop inner local var - "SingleStatementExec", // drop outer local var + "DropVariable", // drop for query var intCol1 + "DropVariable", // drop for query var intCol )) } @@ -877,12 +878,12 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, 
"intCol", "query1"), variableName = Some("x"), + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), new IterateStatementExec("lbl1"), - TestLeafStatement("statement2"))), - label = Some("lbl1"), - session = spark + TestLeafStatement("statement2"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -891,8 +892,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "lbl1", "statement1", "lbl1", - "SingleStatementExec", // drop local var - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol + "DropVariable", // drop for loop var x )) } @@ -901,19 +902,16 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), new LeaveStatementExec("lbl1"), - TestLeafStatement("statement2"))), - label = Some("lbl1"), - session = spark + TestLeafStatement("statement2"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq( - "statement1", - "lbl1" - )) + assert(statements === Seq("statement1", "lbl1")) } test("for statement - nested - iterate outer loop") { @@ -921,21 +919,21 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("outer_body"), new ForStatementExec( query = MockQuery(2, "intCol1", "query2"), variableName = Some("y"), + label = Some("lbl2"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), new IterateStatementExec("lbl1"), - TestLeafStatement("body2"))), - label = Some("lbl2"), - session = spark + TestLeafStatement("body2"))) ) - )), - label = Some("lbl1"), - session = spark + )) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -946,8 +944,8 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "outer_body", "body1", "lbl1", - "SingleStatementExec", // drop local var - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol + "DropVariable", // drop for loop var x )) } @@ -956,27 +954,24 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = Some("x"), + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( new ForStatementExec( query = MockQuery(2, "intCol", "query2"), variableName = Some("y"), + label = Some("lbl2"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), new LeaveStatementExec("lbl1"), - TestLeafStatement("body2"))), - label = Some("lbl2"), - session = spark + TestLeafStatement("body2"))) ) - )), - label = Some("lbl1"), - session = spark + )) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq - assert(statements === Seq( - "body1", - "lbl1" - )) + assert(statements === Seq("body1", "lbl1")) } test("for statement no variable - iterate") { @@ -984,18 +979,18 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = None, + label = Some("lbl1"), + 
session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), new IterateStatementExec("lbl1"), - TestLeafStatement("statement2"))), - label = Some("lbl1"), - session = spark + TestLeafStatement("statement2"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "statement1", "lbl1", "statement1", "lbl1", - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol )) } @@ -1004,12 +999,12 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = None, + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("statement1"), new LeaveStatementExec("lbl1"), - TestLeafStatement("statement2"))), - label = Some("lbl1"), - session = spark + TestLeafStatement("statement2"))) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq @@ -1021,27 +1016,27 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = None, + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("outer_body"), new ForStatementExec( query = MockQuery(2, "intCol1", "query2"), variableName = None, + label = Some("lbl2"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), new IterateStatementExec("lbl1"), - TestLeafStatement("body2"))), - label = Some("lbl2"), - session = spark + TestLeafStatement("body2"))) ) - )), - label = Some("lbl1"), - session = spark + )) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1", - "SingleStatementExec", // drop local var + "DropVariable", // drop for query var intCol )) } @@ -1050,20 +1045,20 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi new ForStatementExec( query = MockQuery(2, "intCol", "query1"), variableName = None, + label = Some("lbl1"), + session = spark, body = new CompoundBodyExec(Seq( new ForStatementExec( query = MockQuery(2, "intCol1", "query2"), variableName = None, + label = Some("lbl2"), + session = spark, body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), new LeaveStatementExec("lbl1"), TestLeafStatement("body2"))), - label = Some("lbl2"), - session = spark ) )), - label = Some("lbl1"), - session = spark ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq From c46750a139dbefffe868c58ae2b2c94d491c0bcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 21 Nov 2024 15:48:59 +0100 Subject: [PATCH 32/39] implement Daniel's suggestions --- .../sql/catalyst/parser/AstBuilder.scala | 30 +++++++++---------- .../logical/SqlScriptingLogicalPlans.scala | 10 ++----- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 81b50344e2ffb..56a55bd9ad208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -369,28 +369,28 @@ class AstBuilder extends DataTypeAstBuilder ctx: RuleContext, label: String, isIterate: Boolean): Boolean = { ctx match { case c:
BeginEndCompoundBlockContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) => - if (isIterate) { + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => if (isIterate) { throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label) } true case c: WhileStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: RepeatStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: LoopStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case c: ForStatementContext - if Option(c.beginLabel()).isDefined && - c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) - => true + if Option(c.beginLabel()).exists { b => + b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + } => true case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index 1d3ebe87a6b9e..517bb1ead71f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -290,12 +290,8 @@ case class ForStatement( override def children: Seq[LogicalPlan] = Seq(query, body) override protected def withNewChildrenInternal( - newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = { - assert(newChildren.length == 2) - ForStatement( - newChildren(0).asInstanceOf[SingleStatement], - variableName, - newChildren(1).asInstanceOf[CompoundBody], - label) + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = newChildren match { + case IndexedSeq(query: SingleStatement, body: CompoundBody) => + ForStatement(query, variableName, body, label) } } From e485c99344f1059b59a8912ba0c1c7746fa50663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 21 Nov 2024 16:28:59 +0100 Subject: [PATCH 33/39] move isExecuted out of buildDataFrame --- .../apache/spark/sql/scripting/SqlScriptingExecutionNode.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index ce3ee18c28b0b..06a1e905d4b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -131,7 +131,6 @@ class SingleStatementExec( * The DataFrame.
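 * Note: after this change the method only builds the DataFrame and no longer
 * flips `isExecuted`; a caller that materializes a query without treating it
 * as executed (e.g. ForStatementExec caching its result set) sets the flag
 * itself, as done in cachedQueryResult() later in this patch.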
*/ def buildDataFrame(session: SparkSession): DataFrame = { - isExecuted = true Dataset.ofRows(session, parsedPlan) } @@ -684,6 +683,7 @@ class ForStatementExec( private def cachedQueryResult(): Array[Row] = { if (!isResultCacheValid) { queryResult = query.buildDataFrame(session).collect() + query.isExecuted = true isResultCacheValid = true } queryResult From 112b86092e676c2812117c93f0ede0c5a0f1f02e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 21 Nov 2024 17:16:45 +0100 Subject: [PATCH 34/39] fix scalastyle --- .../SqlScriptingExecutionNodeSuite.scala | 22 ++++---- .../SqlScriptingInterpreterSuite.scala | 55 ++++++++++--------- 2 files changed, 40 insertions(+), 37 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index c3feaf63aa07e..454cee5ea73de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -744,7 +744,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "statement1", "statement2", "DropVariable", // drop for query var intCol - "DropVariable", // drop for loop var x + "DropVariable" // drop for loop var x )) } @@ -791,7 +791,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "DropVariable", // drop for query var intCol1 "DropVariable", // drop for loop var y "DropVariable", // drop for query var intCol - "DropVariable", // drop for loop var x + "DropVariable" // drop for loop var x )) } @@ -808,7 +808,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "body", - "DropVariable", // drop for query var intCol + "DropVariable" // drop for query var intCol )) } @@ -827,7 +827,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "statement1", "statement2", "statement1", "statement2", - "DropVariable", // drop for query var intCol + "DropVariable" // drop for query var intCol )) } @@ -869,7 +869,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "DropVariable", // drop for query var intCol1 "body", "body", "DropVariable", // drop for query var intCol1 - "DropVariable", // drop for query var intCol + "DropVariable" // drop for query var intCol )) } @@ -893,7 +893,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "statement1", "lbl1", "DropVariable", // drop for query var intCol - "DropVariable", // drop for loop var x + "DropVariable" // drop for loop var x )) } @@ -945,7 +945,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi "body1", "lbl1", "DropVariable", // drop for query var intCol - "DropVariable", // drop for loop var x + "DropVariable" // drop for loop var x )) } @@ -990,7 +990,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "statement1", "lbl1", "statement1", "lbl1", - "DropVariable", // drop for query var intCol + "DropVariable" // drop for query var intCol )) } @@ -1036,7 +1036,7 @@ class SqlScriptingExecutionNodeSuite 
extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq( "outer_body", "body1", "lbl1", "outer_body", "body1", "lbl1", - "DropVariable", // drop for query var intCol + "DropVariable" // drop for query var intCol )) } @@ -1056,9 +1056,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi body = new CompoundBodyExec(Seq( TestLeafStatement("body1"), new LeaveStatementExec("lbl1"), - TestLeafStatement("body2"))), + TestLeafStatement("body2"))) ) - )), + )) ) )).getTreeIterator val statements = iter.map(extractStatementValue).toSeq diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 30a270294eaa3..2dcbf06c7ecac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1568,7 +1568,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1612,7 +1612,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1623,7 +1623,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN - | CREATE TABLE t (int_column INT, map_column MAP, struct_column STRUCT, array_column ARRAY); + | CREATE TABLE t (int_column INT, map_column MAP, + | struct_column STRUCT, array_column ARRAY); | INSERT INTO t VALUES | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); @@ -1657,7 +1658,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1669,7 +1670,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { """ |BEGIN | CREATE TABLE t - | (int_column INT, struct_column STRUCT>>); + | (int_column INT, + | struct_column STRUCT>>); | INSERT INTO t VALUES | (1, STRUCT(1, STRUCT(STRUCT("one")))), | (2, STRUCT(2, STRUCT(STRUCT("two")))); @@ -1689,7 +1691,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(Row(2, Row(Row("two"))))), // select struct_column Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1720,7 +1722,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1752,7 +1754,7 
@@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -1771,7 +1773,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |""".stripMargin val expected = Seq( - Seq.empty[Row], // create table + Seq.empty[Row] // create table ) verifySqlScriptResult(sqlScript, expected) } @@ -1806,7 +1808,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row("fourth")), // select x.stringCol Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2032,7 +2034,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // insert Seq.empty[Row], // insert Seq(Row(3)), // select y.intCol2 - Seq(Row(3)), // select intCol2 + Seq(Row(3)) // select intCol2 ) verifySqlScriptResult(sqlScript, expected) } @@ -2068,7 +2070,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(3)), // select y.intCol2 Seq(Row(3)), // select intCol2 Seq.empty[Row], // drop outer var - Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var ) verifySqlScriptResult(sqlScript, expected) } @@ -2097,7 +2099,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(1.0)), // select doubleCol Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2131,7 +2133,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(2.0)), // select doubleCol Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2142,7 +2144,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN - | CREATE TABLE t (int_column INT, map_column MAP, struct_column STRUCT, array_column ARRAY); + | CREATE TABLE t (int_column INT, map_column MAP, + | struct_column STRUCT, array_column ARRAY); | INSERT INTO t VALUES | (1, MAP('a', 1), STRUCT('John', 25), ARRAY('apricot', 'quince')), | (2, MAP('b', 2), STRUCT('Jane', 30), ARRAY('plum', 'pear')); @@ -2166,7 +2169,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2177,8 +2180,8 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { val sqlScript = """ |BEGIN - | CREATE TABLE t - | (int_column INT, struct_column STRUCT>>); + | CREATE TABLE t (int_column INT, + | struct_column STRUCT>>); | INSERT INTO t VALUES | (1, STRUCT(1, STRUCT(STRUCT("one")))), | (2, STRUCT(2, STRUCT(STRUCT("two")))); @@ -2194,7 +2197,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(Row(1, Row(Row("one"))))), // select struct_column Seq(Row(Row(2, Row(Row("two"))))), // select 
struct_column Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2221,7 +2224,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(Map("a" -> Map(1 -> Map(false -> 10))))), // select map_column Seq(Row(Map("b" -> Map(2 -> Map(true -> 20))))), // select map_column Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2249,7 +2252,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row(Seq(Seq(Seq(1, 2), Seq(3, 4)), Seq(Seq(5, 6))))), // array_column Seq(Row(Array(Seq(Seq(7, 8), Seq(9, 10)), Seq(Seq(11, 12))))), // array_column Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2268,7 +2271,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { |""".stripMargin val expected = Seq( - Seq.empty[Row], // create table + Seq.empty[Row] // create table ) verifySqlScriptResult(sqlScript, expected) } @@ -2298,7 +2301,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq(Row("third")), // select stringCol Seq(Row("fourth")), // select stringCol Seq.empty[Row], // drop local var - Seq.empty[Row], // drop local var + Seq.empty[Row] // drop local var ) verifySqlScriptResult(sqlScript, expected) } @@ -2325,7 +2328,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq(Row("first")), // select stringCol - Seq(Row("second")), // select stringCol + Seq(Row("second")) // select stringCol ) verifySqlScriptResult(sqlScript, expected) } @@ -2498,7 +2501,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // create table Seq.empty[Row], // insert Seq.empty[Row], // insert - Seq(Row(3)), // select intCol2 + Seq(Row(3)) // select intCol2 ) verifySqlScriptResult(sqlScript, expected) } @@ -2530,7 +2533,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { Seq.empty[Row], // insert Seq(Row(3)), // select intCol2 Seq(Row(3)), // select intCol2 - Seq.empty[Row], // drop outer var + Seq.empty[Row] // drop outer var ) verifySqlScriptResult(sqlScript, expected) } From 54271d0dd1c3d5898aef3cb9df0fd57996c3a43c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 21 Nov 2024 17:19:02 +0100 Subject: [PATCH 35/39] formatting --- .../sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala index 517bb1ead71f1..4faf1f5d26672 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala @@ -290,7 +290,7 @@ case class ForStatement( override def children: Seq[LogicalPlan] = Seq(query, body) override protected def withNewChildrenInternal( - newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = newChildren match { + newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = 
newChildren match { case IndexedSeq(query: SingleStatement, body: CompoundBody) => ForStatement(query, variableName, body, label) } From 223612e9059e06697a86102b00ff250501b9b9af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Fri, 22 Nov 2024 17:49:45 +0100 Subject: [PATCH 36/39] refactor collect() to toLocalIterator() --- .../scripting/SqlScriptingExecutionNode.scala | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index 06a1e905d4b9c..f7cefac8d1068 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType +import java.util + /** * Trait for all SQL scripting execution nodes used during interpretation phase. */ @@ -668,7 +670,6 @@ class ForStatementExec( val VariableAssignment, Body, VariableCleanup = Value } private var state = ForState.VariableAssignment - private var currRow = 0 private var areVariablesDeclared = false // map of all variables created internally by the for statement @@ -678,11 +679,11 @@ class ForStatementExec( // compound body used for dropping variables while in ForState.VariableAssignment private var dropVariablesExec: CompoundBodyExec = null - private var queryResult: Array[Row] = null + private var queryResult: util.Iterator[Row] = _ private var isResultCacheValid = false - private def cachedQueryResult(): Array[Row] = { + private def cachedQueryResult(): util.Iterator[Row] = { if (!isResultCacheValid) { - queryResult = query.buildDataFrame(session).collect() + queryResult = query.buildDataFrame(session).toLocalIterator() query.isExecuted = true isResultCacheValid = true } @@ -697,16 +698,16 @@ class ForStatementExec( private lazy val treeIterator: Iterator[CompoundStatementExec] = new Iterator[CompoundStatementExec] { - override def hasNext: Boolean = { - val resultSize = cachedQueryResult().length - (state == ForState.VariableCleanup && dropVariablesExec.getTreeIterator.hasNext) || - (!interrupted && resultSize > 0 && currRow < resultSize) - } + override def hasNext: Boolean = !interrupted && (state match { + case ForState.VariableAssignment => cachedQueryResult().hasNext + case ForState.Body => true + case ForState.VariableCleanup => dropVariablesExec.getTreeIterator.hasNext + }) override def next(): CompoundStatementExec = state match { case ForState.VariableAssignment => - variablesMap = createVariablesMapFromRow(cachedQueryResult()(currRow)) + variablesMap = createVariablesMapFromRow(cachedQueryResult().next()) if (!areVariablesDeclared) { // create and execute declare var statements @@ -821,8 +822,7 @@ class ForStatementExec( } private def switchStateFromBody(): Unit = { - currRow += 1 - state = if (currRow < cachedQueryResult().length) ForState.VariableAssignment + state = if (cachedQueryResult().hasNext) ForState.VariableAssignment else { // create compound body for dropping nodes after execution is complete dropVariablesExec = new CompoundBodyExec( @@ -862,7 +862,6 @@ class ForStatementExec( override def reset(): Unit = { state = ForState.VariableAssignment isResultCacheValid = false - currRow = 0 
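    // Invalidating the cached iterator forces the next execution to re-run the
    // FOR query via toLocalIterator(); the per-row variable bookkeeping below
    // is rebuilt from scratch on the first iteration.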
variablesMap = Map() areVariablesDeclared = false dropVariablesExec = null From 9bf0b7a0142fee73de84309af4b3d26d0d014d91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Mon, 25 Nov 2024 11:56:31 +0100 Subject: [PATCH 37/39] fix scalastyle --- .../spark/sql/scripting/SqlScriptingExecutionNode.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index f7cefac8d1068..9a42efa69a02c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.scripting +import java.util + import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} @@ -27,8 +29,6 @@ import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin} import org.apache.spark.sql.errors.SqlScriptingErrors import org.apache.spark.sql.types.BooleanType -import java.util - /** * Trait for all SQL scripting execution nodes used during interpretation phase. */ From 9d1cf293aefd81d8e161e52ca33129660d4b5f3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Tue, 26 Nov 2024 13:32:19 +0100 Subject: [PATCH 38/39] add sum test --- .../SqlScriptingInterpreterSuite.scala | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index 2dcbf06c7ecac..998944d08951c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1618,6 +1618,38 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for statement - sum of column from table") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE sumOfCols = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (1), (2), (3), (4); + | FOR row AS SELECT * FROM t DO + | SET sumOfCols = sumOfCols + row.intCol; + | END FOR; + | SELECT sumOfCols; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare sumOfCols + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // drop local var + Seq.empty[Row], // drop local var + Seq(Row(10)), // select sumOfCols + Seq.empty[Row] // drop sumOfCols + ) + verifySqlScriptResult(sqlScript, expected) + } + } + test("for statement - map, struct, array") { withTable("t") { val sqlScript = @@ -2139,6 +2171,37 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { } } + test("for statement - no variable - sum of column from table") { + withTable("t") { + val sqlScript = + """ + |BEGIN + | DECLARE sumOfCols = 0; + | CREATE TABLE t (intCol INT) using parquet; + | INSERT INTO t VALUES (1), (2), (3), (4); + | FOR SELECT * FROM t DO + | SET sumOfCols = sumOfCols + intCol; + | END FOR; + | SELECT sumOfCols; + |END + |""".stripMargin + + val expected = Seq( + 
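      // One result entry per executed statement: the SET runs once per row of
      // t, and only the final SELECT returns rows.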
Seq.empty[Row], // declare sumOfCols + Seq.empty[Row], // create table + Seq.empty[Row], // insert + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // set sumOfCols + Seq.empty[Row], // drop local var + Seq(Row(10)), // select sumOfCols + Seq.empty[Row] // drop sumOfCols + ) + verifySqlScriptResult(sqlScript, expected) + } + } + test("for statement - no variable - map, struct, array") { withTable("t") { val sqlScript = From 3b3aebe883a5c113662301ba2d7d75032defde29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Wed, 27 Nov 2024 14:57:02 +0100 Subject: [PATCH 39/39] fix exec node test --- .../spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 0fb0f52a6e884..a997b5beadd34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -99,6 +99,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi extends SingleStatementExec( DummyLogicalPlan(), Origin(startIndex = Some(0), stopIndex = Some(description.length)), + Map.empty, isInternal = false) { override def buildDataFrame(session: SparkSession): DataFrame = { val data = Seq.range(0, numberOfRows).map(Row(_))
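(For context: MockQuery above produces an n-row, single INT column DataFrame, so the executor tests can drive ForStatementExec without creating a real table.)

Taken together, the series makes scripts like the following run end to end. This is a usage sketch assembled from the interpreter-suite tests above, not a verbatim excerpt from the patch:

BEGIN
  DECLARE sumOfCols = 0;
  CREATE TABLE t (intCol INT) USING parquet;
  INSERT INTO t VALUES (1), (2), (3), (4);
  lbl: FOR row AS SELECT intCol FROM t DO
    SET sumOfCols = sumOfCols + row.intCol;
  END FOR lbl;
  SELECT sumOfCols; -- returns 10
END

Each row of the query is exposed both through the optional loop variable (row.intCol) and as a bare column name (intCol), and both flavors of variable are dropped automatically once the loop finishes, matching the DropVariable entries the node-suite tests assert.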