diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index cd7af021d8ff..523b7c88fc8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -445,7 +445,7 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitSearchedCaseStatementImpl(
       ctx: SearchedCaseStatementContext,
-      labelCtx: SqlScriptingLabelContext): CaseStatement = {
+      labelCtx: SqlScriptingLabelContext): SearchedCaseStatement = {
     val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
       SingleStatement(
         Project(
@@ -464,7 +464,7 @@ class AstBuilder extends DataTypeAstBuilder
         s" ${conditionalBodies.length} in case statement")
     }
 
-    CaseStatement(
+    SearchedCaseStatement(
       conditions = conditions,
       conditionalBodies = conditionalBodies,
       elseBody = Option(ctx.elseBody).map(
@@ -475,30 +475,31 @@ class AstBuilder extends DataTypeAstBuilder
 
   private def visitSimpleCaseStatementImpl(
       ctx: SimpleCaseStatementContext,
-      labelCtx: SqlScriptingLabelContext): CaseStatement = {
-    // uses EqualTo to compare the case variable(the main case expression)
-    // to the WHEN clause expressions
-    val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
-      SingleStatement(
-        Project(
-          Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()),
-          OneRowRelation()))
-    })
+      labelCtx: SqlScriptingLabelContext): SimpleCaseStatement = {
+    val caseVariableExpr = withOrigin(ctx.caseVariable) {
+      expression(ctx.caseVariable)
+    }
+    val conditionExpressions =
+      ctx.conditionExpressions.asScala.toList
+        .map(exprCtx => withOrigin(exprCtx) {
+          expression(exprCtx)
+        })
     val conditionalBodies = ctx.conditionalBodies.asScala.toList.map(
       body => visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
     )
-    if (conditions.length != conditionalBodies.length) {
+    if (conditionExpressions.length != conditionalBodies.length) {
       throw SparkException.internalError(
-        s"Mismatched number of conditions ${conditions.length} and condition bodies" +
+        s"Mismatched number of conditions ${conditionExpressions.length} and condition bodies" +
           s" ${conditionalBodies.length} in case statement")
     }
-    CaseStatement(
-      conditions = conditions,
-      conditionalBodies = conditionalBodies,
+    SimpleCaseStatement(
+      caseVariableExpr,
+      conditionExpressions,
+      conditionalBodies,
       elseBody = Option(ctx.elseBody).map(
         body =>
           visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index bbbdd3b09a3c..2073b296a0dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -21,7 +21,7 @@ import java.util.Locale
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -220,13 +220,24 @@ case class IterateStatement(label: String) extends CompoundPlanStatement {
 }
 
 /**
- * Logical operator for CASE statement.
+ * Logical operator for CASE statement, SEARCHED variant.
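+ * Conditions are full boolean expressions, each wrapped by the parser in a
+ * SingleStatement; at execution time they are evaluated in order, and the body
+ * paired with the first condition that evaluates to true is executed.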
+ * Example:
+ * {{{
+ * CASE
+ *   WHEN x = 1 THEN
+ *     SELECT 1;
+ *   WHEN x = 2 THEN
+ *     SELECT 2;
+ *   ELSE
+ *     SELECT 3;
+ * END CASE;
+ * }}}
  * @param conditions Collection of conditions which correspond to WHEN clauses.
  * @param conditionalBodies Collection of bodies that have a corresponding condition,
  *   in WHEN branches.
  * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
  */
-case class CaseStatement(
+case class SearchedCaseStatement(
     conditions: Seq[SingleStatement],
     conditionalBodies: Seq[CompoundBody],
     elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
@@ -253,7 +264,44 @@ case class CaseStatement(
       conditionalBodies = conditionalBodies.dropRight(1)
       elseBody = Some(conditionalBodies.last)
     }
-    CaseStatement(conditions, conditionalBodies, elseBody)
+    SearchedCaseStatement(conditions, conditionalBodies, elseBody)
   }
 }
 
+/**
+ * Logical operator for CASE statement, SIMPLE variant.
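+ * Unlike the SEARCHED variant, the WHEN clauses hold plain expressions rather
+ * than boolean conditions: the case variable is evaluated once and compared for
+ * equality with each WHEN expression, in order, at execution time.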
+ * Example:
+ * {{{
+ * CASE x
+ *   WHEN 1 THEN
+ *     SELECT 1;
+ *   WHEN 2 THEN
+ *     SELECT 2;
+ *   ELSE
+ *     SELECT 3;
+ * END CASE;
+ * }}}
+ * @param caseVariableExpression Expression to which all conditionExpressions are compared.
+ * @param conditionExpressions Collection of expressions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of bodies that have a corresponding condition,
+ *   in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
+ */
+case class SimpleCaseStatement(
+    caseVariableExpression: Expression,
+    conditionExpressions: Seq[Expression],
+    conditionalBodies: Seq[CompoundBody],
+    elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
+  assert(conditionExpressions.length == conditionalBodies.length)
+
+  override def output: Seq[Attribute] = Seq.empty
+
+  override def children: Seq[LogicalPlan] = conditionalBodies
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
+    val conditionalBodies = newChildren.map(_.asInstanceOf[CompoundBody])
+    SimpleCaseStatement(caseVariableExpression, conditionExpressions, conditionalBodies, elseBody)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 9de5d09feb76..99658e1ff294 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{In, Literal, ScalarSubquery}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SetVariable, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SearchedCaseStatement, SetVariable, SimpleCaseStatement, SingleStatement, WhileStatement}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
 import org.apache.spark.sql.exceptions.SqlScriptingException
 import org.apache.spark.sql.internal.SQLConf
@@ -1462,8 +1462,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       |""".stripMargin
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
     assert(caseStmt.conditions.length == 1)
     assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
     assert(caseStmt.conditions.head.getText == "1 = 1")
@@ -1502,9 +1502,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
     assert(caseStmt.conditions.length == 3)
     assert(caseStmt.conditionalBodies.length == 3)
     assert(caseStmt.elseBody.isEmpty)
@@ -1545,8 +1545,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       |""".stripMargin
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
     assert(caseStmt.elseBody.isDefined)
     assert(caseStmt.conditions.length == 1)
     assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
@@ -1574,9 +1574,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
     assert(caseStmt.conditions.length == 1)
     assert(caseStmt.conditionalBodies.length == 1)
     assert(caseStmt.elseBody.isEmpty)
@@ -1584,9 +1584,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
     assert(caseStmt.conditions.head.getText == "1 = 1")
 
-    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SearchedCaseStatement])
     val nestedCaseStmt =
-      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SearchedCaseStatement]
 
     assert(nestedCaseStmt.conditions.length == 1)
     assert(nestedCaseStmt.conditionalBodies.length == 1)
@@ -1616,11 +1616,15 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       |""".stripMargin
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
-    assert(caseStmt.conditions.length == 1)
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+    assert(caseStmt.caseVariableExpression == Literal(1))
+    assert(caseStmt.conditionExpressions.length == 1)
+    assert(caseStmt.conditionExpressions.head == Literal(1))
+
+    assert(caseStmt.conditionalBodies.length == 1)
+    assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 1")
   }
 
   test("simple case statement with empty body") {
@@ -1656,31 +1660,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
-    assert(caseStmt.conditions.length == 3)
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+    assert(caseStmt.caseVariableExpression == Literal(1))
+    assert(caseStmt.conditionExpressions.length == 3)
     assert(caseStmt.conditionalBodies.length == 3)
     assert(caseStmt.elseBody.isEmpty)
 
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+    assert(caseStmt.conditionExpressions.head == Literal(1))
     assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
     assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
       .getText == "SELECT 1")
 
-    assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery])
+    assert(caseStmt.conditionExpressions(1).isInstanceOf[ScalarSubquery])
     assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
     assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
       .getText == "SELECT * FROM b")
 
-    assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In])
+    assert(caseStmt.conditionExpressions(2).isInstanceOf[In])
     assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
     assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
@@ -1701,12 +1701,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       |""".stripMargin
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+
+    assert(caseStmt.caseVariableExpression == Literal(1))
     assert(caseStmt.elseBody.isDefined)
-    assert(caseStmt.conditions.length == 1)
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+    assert(caseStmt.conditionExpressions.length == 1)
+    assert(caseStmt.conditionExpressions.head == Literal(1))
+
+    assert(caseStmt.conditionalBodies.length == 1)
+    assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+      .getText == "SELECT 42")
 
     assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
     assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
@@ -1730,28 +1735,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       |""".stripMargin
     val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
     assert(tree.collection.length == 1)
-    assert(tree.collection.head.isInstanceOf[CaseStatement])
+    assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
 
-    val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
-    assert(caseStmt.conditions.length == 1)
+    val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+
+    assert(caseStmt.caseVariableExpression.isInstanceOf[ScalarSubquery])
+    assert(caseStmt.conditionExpressions.length == 1)
     assert(caseStmt.conditionalBodies.length == 1)
     assert(caseStmt.elseBody.isEmpty)
 
-    assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == Literal(1))
+    assert(caseStmt.conditionExpressions.head == Literal(1))
 
-    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+    assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SimpleCaseStatement])
     val nestedCaseStmt =
-      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+      caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SimpleCaseStatement]
 
-    assert(nestedCaseStmt.conditions.length == 1)
+    assert(nestedCaseStmt.caseVariableExpression == Literal(2))
+    assert(nestedCaseStmt.conditionExpressions.length == 1)
     assert(nestedCaseStmt.conditionalBodies.length == 1)
     assert(nestedCaseStmt.elseBody.isDefined)
 
-    assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
-    checkSimpleCaseStatementCondition(
-      nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2))
+    assert(nestedCaseStmt.conditionExpressions.head == Literal(2))
 
     assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
     assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
@@ -2910,17 +2914,4 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
       .replace("END", "")
       .trim
   }
-
-  private def checkSimpleCaseStatementCondition(
-      conditionStatement: SingleStatement,
-      predicateLeft: Expression => Boolean,
-      predicateRight: Expression => Boolean): Unit = {
-    assert(conditionStatement.parsedPlan.isInstanceOf[Project])
-    val project = conditionStatement.parsedPlan.asInstanceOf[Project]
-    assert(project.projectList.head.isInstanceOf[Alias])
-    assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo])
-    val equalTo = project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo]
-    assert(predicateLeft(equalTo.left))
-    assert(predicateRight(equalTo.right))
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index ce0876e8f629..62edf3e46477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.{ExecuteImmediateQuery, NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
-import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable}
 import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
 import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
@@ -528,14 +528,14 @@
 }
 
 /**
- * Executable node for CaseStatement.
+ * Executable node for SearchedCaseStatement.
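+ * Conditions are evaluated one at a time, in order; the first condition that
+ * evaluates to true selects the body to execute, and the ELSE body, if present,
+ * runs only when every condition evaluates to false.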
  * @param conditions Collection of executable conditions which correspond to WHEN clauses.
  * @param conditionalBodies Collection of executable bodies that have a corresponding condition,
  *   in WHEN branches.
  * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
  * @param session Spark session that SQL script is executed within.
  */
-class CaseStatementExec(
+class SearchedCaseStatementExec(
     conditions: Seq[SingleStatementExec],
     conditionalBodies: Seq[CompoundBodyExec],
     elseBody: Option[CompoundBodyExec],
@@ -599,6 +599,118 @@ class CaseStatementExec(
   }
 }
 
+/**
+ * Executable node for SimpleCaseStatement.
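+ * The case variable is evaluated lazily and only once: its value is cached as a
+ * Literal, and each WHEN expression is wrapped in an EqualTo comparison against
+ * that literal, evaluated as a one-row boolean projection.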
+ * @param caseVariableExec Statement to which all conditionExpressions are compared.
+ * @param conditionExpressions Collection of expressions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a corresponding condition,
+ *   in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ * @param context SqlScriptingExecutionContext keeps the execution state of current script.
+ */
+class SimpleCaseStatementExec(
+    caseVariableExec: SingleStatementExec,
+    conditionExpressions: Seq[Expression],
+    conditionalBodies: Seq[CompoundBodyExec],
+    elseBody: Option[CompoundBodyExec],
+    session: SparkSession,
+    context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
+  private object CaseState extends Enumeration {
+    val Condition, Body = Value
+  }
+
+  private var state = CaseState.Condition
+  var bodyExec: Option[CompoundBodyExec] = None
+
+  var conditionBodyTupleIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = _
+  private var caseVariableLiteral: Literal = _
+
+  private var isCacheValid = false
+  private def validateCache(): Unit = {
+    if (!isCacheValid) {
+      val values = caseVariableExec.buildDataFrame(session).collect()
+      caseVariableExec.isExecuted = true
+
+      caseVariableLiteral = Literal(values.head.get(0))
+      conditionBodyTupleIterator = createConditionBodyIterator
+      isCacheValid = true
+    }
+  }
+
+  private def cachedCaseVariableLiteral: Literal = {
+    validateCache()
+    caseVariableLiteral
+  }
+
+  private def cachedConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = {
+    validateCache()
+    conditionBodyTupleIterator
+  }
+
+  private lazy val treeIterator: Iterator[CompoundStatementExec] =
+    new Iterator[CompoundStatementExec] {
+      override def hasNext: Boolean = state match {
+        case CaseState.Condition => cachedConditionBodyIterator.hasNext || elseBody.isDefined
+        case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
+      }
+
+      override def next(): CompoundStatementExec = state match {
+        case CaseState.Condition =>
+          cachedConditionBodyIterator.nextOption()
+            .map { case (condStmt, body) =>
+              if (evaluateBooleanCondition(session, condStmt)) {
+                bodyExec = Some(body)
+                state = CaseState.Body
+              }
+              condStmt
+            }
+            .orElse(elseBody.map { body => {
+              bodyExec = Some(body)
+              state = CaseState.Body
+              next()
+            }})
+            .get
+        case CaseState.Body => bodyExec.get.getTreeIterator.next()
+      }
+    }
+
+  private def createConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] =
+    conditionExpressions.zip(conditionalBodies)
+      .iterator
+      .map { case (expr, body) =>
+        val condition = Project(
+          Seq(Alias(EqualTo(cachedCaseVariableLiteral, expr), "condition")()),
+          OneRowRelation()
+        )
+        // We hack the Origin to provide more descriptive error messages. For example, if
+        // the case variable is 1 and the condition expression it's compared to is 5, we
+        // will get an Origin with the text "(1 = 5)".
+        val conditionText = condition.projectList.head.asInstanceOf[Alias].child.toString
+        val condStmt = new SingleStatementExec(
+          condition,
+          Origin(sqlText = Some(conditionText),
+            startIndex = Some(0),
+            stopIndex = Some(conditionText.length - 1),
+            line = caseVariableExec.origin.line),
+          Map.empty,
+          isInternal = true,
+          context = context
+        )
+        (condStmt, body)
+      }
+
+  override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+  override def reset(): Unit = {
+    state = CaseState.Condition
+    isCacheValid = false
+    caseVariableExec.reset()
+    conditionalBodies.foreach(b => b.reset())
+    elseBody.foreach(b => b.reset())
+  }
+}
+
 /**
  * Executable node for RepeatStatement.
  * @param condition Executable node for the condition - evaluates to a row with a single boolean
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 919d63bb0c71..4485e780c8bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.scripting
 import scala.collection.mutable.HashMap
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, OneRowRelation, Project, RepeatStatement, SearchedCaseStatement, SimpleCaseStatement, SingleStatement, WhileStatement}
 import org.apache.spark.sql.catalyst.trees.CurrentOrigin
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -187,7 +187,7 @@ case class SqlScriptingInterpreter(session: SparkSession) {
         new IfElseStatementExec(
           conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
 
-      case CaseStatement(conditions, conditionalBodies, elseBody) =>
+      case SearchedCaseStatement(conditions, conditionalBodies, elseBody) =>
         val conditionsExec = conditions.map(condition =>
          new SingleStatementExec(
            condition.parsedPlan,
@@ -199,9 +199,25 @@ case class SqlScriptingInterpreter(session: SparkSession) {
           transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
         val unconditionalBodiesExec = elseBody.map(body =>
           transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
-        new CaseStatementExec(
+        new SearchedCaseStatementExec(
           conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
 
+      case SimpleCaseStatement(caseExpr, conditionExpressions, conditionalBodies, elseBody) =>
+        val caseValueStmt = SingleStatement(
+          Project(Seq(Alias(caseExpr, "caseVariable")()), OneRowRelation()))
+        val caseVarExec = new SingleStatementExec(
+          caseValueStmt.parsedPlan,
+          caseExpr.origin,
+          args,
+          isInternal = true,
+          context)
+        val conditionalBodiesExec = conditionalBodies.map(body =>
+          transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
+        val elseBodyExec = elseBody.map(body =>
+          transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
+        new SimpleCaseStatementExec(
+          caseVarExec, conditionExpressions, conditionalBodiesExec, elseBodyExec, session, context)
+
       case WhileStatement(condition, body, label) =>
         val conditionExec = new SingleStatementExec(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index f83ae87290ec..2d732f6f453a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -77,10 +77,28 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
     null
   )
 
+  case class TestIntegerProjection(value: Int, description: String)
+    extends SingleStatementExec(
+      parsedPlan = Project(Seq(Alias(Literal(value), description)()), OneRowRelation()),
+      Origin(startIndex = Some(0), stopIndex = Some(description.length)),
+      Map.empty,
+      isInternal = false,
+      null
+    )
+
   case class DummyLogicalPlan() extends LeafNode {
     override def output: Seq[Attribute] = Seq.empty
   }
 
+  case class MockScriptingContext() extends SqlScriptingExecutionContext {
+    override def enterScope(
+        label: String,
+        triggerHandlerMap: TriggerToExceptionHandlerMap
+    ): Unit = ()
+
+    override def exitScope(label: String): Unit = ()
+  }
+
   case class TestLoopCondition(
     condVal: Boolean, reps: Int, description: String)
     extends SingleStatementExec(
@@ -162,7 +180,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
       => "DropVariable"
     case execImm: SingleStatementExec if execImm.parsedPlan.isInstanceOf[ExecuteImmediateQuery]
       => "ExecuteImmediate"
-    case _ => fail("Unexpected statement type")
+    case project: SingleStatementExec if project.parsedPlan.isInstanceOf[Project]
+      => "Project"
+    case _ => fail("Unexpected statement: " + statement)
   }
 
   // Tests
@@ -642,7 +662,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
 
   test("searched case - enter first WHEN clause") {
     val iter = TestCompoundBody(Seq(
-      new CaseStatementExec(
+      new SearchedCaseStatementExec(
         conditions = Seq(
           TestIfElseCondition(condVal = true, description = "con1"),
           TestIfElseCondition(condVal = false, description = "con2")
@@ -661,7 +681,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
 
   test("searched case - enter body of the ELSE clause") {
     val iter = TestCompoundBody(Seq(
-      new CaseStatementExec(
+      new SearchedCaseStatementExec(
         conditions = Seq(
           TestIfElseCondition(condVal = false, description = "con1")
         ),
@@ -678,7 +698,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
 
   test("searched case - enter second WHEN clause") {
     val iter = TestCompoundBody(Seq(
-      new CaseStatementExec(
+      new SearchedCaseStatementExec(
         conditions = Seq(
           TestIfElseCondition(condVal = false, description = "con1"),
           TestIfElseCondition(condVal = true, description = "con2")
@@ -697,7 +717,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
 
   test("searched case - without else (successful check)") {
     val iter = TestCompoundBody(Seq(
-      new CaseStatementExec(
+      new SearchedCaseStatementExec(
         conditions = Seq(
           TestIfElseCondition(condVal = false, description = "con1"),
           TestIfElseCondition(condVal = true, description = "con2")
@@ -716,7 +736,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
 
   test("searched case - without else (unsuccessful checks)") {
     val iter = TestCompoundBody(Seq(
-      new CaseStatementExec(
+      new SearchedCaseStatementExec(
         conditions = Seq(
           TestIfElseCondition(condVal = false, description = "con1"),
           TestIfElseCondition(condVal = false, description = "con2")
@@ -733,6 +753,109 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
     assert(statements === Seq("con1", "con2"))
   }
 
+  test("simple case - enter first WHEN clause") {
+    val iter = TestCompoundBody(Seq(
+      new SimpleCaseStatementExec(
+        caseVariableExec = TestIntegerProjection(1, "1"),
+        conditionExpressions = Seq(
+          Literal(1),
+          Literal(2)
+        ),
+        conditionalBodies = Seq(
+          TestCompoundBody(Seq(TestLeafStatement("body1"))),
+          TestCompoundBody(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = Some(TestCompoundBody(Seq(TestLeafStatement("body3")))),
+        session = spark,
+        context = MockScriptingContext()
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("Project", "body1"))
+  }
+
+  test("simple case - enter body of the ELSE clause") {
+    val iter = TestCompoundBody(Seq(
+      new SimpleCaseStatementExec(
+        caseVariableExec = TestIntegerProjection(2, "2"),
+        conditionExpressions = Seq(
+          Literal(1)
+        ),
+        conditionalBodies = Seq(
+          TestCompoundBody(Seq(TestLeafStatement("body1")))
+        ),
+        elseBody = Some(TestCompoundBody(Seq(TestLeafStatement("body2")))),
+        session = spark,
+        context = MockScriptingContext()
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("Project", "body2"))
+  }
+
+  test("simple case - enter second WHEN clause") {
+    val iter = TestCompoundBody(Seq(
+      new SimpleCaseStatementExec(
+        caseVariableExec = TestIntegerProjection(2, "2"),
+        conditionExpressions = Seq(
+          Literal(1),
+          Literal(2)
+        ),
+        conditionalBodies = Seq(
+          TestCompoundBody(Seq(TestLeafStatement("body1"))),
+          TestCompoundBody(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = Some(TestCompoundBody(Seq(TestLeafStatement("body3")))),
+        session = spark,
+        context = MockScriptingContext()
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("Project", "Project", "body2"))
+  }
+
+  test("simple case - without else (successful check)") {
+    val iter = TestCompoundBody(Seq(
+      new SimpleCaseStatementExec(
+        caseVariableExec = TestIntegerProjection(2, "2"),
+        conditionExpressions = Seq(
+          Literal(1),
+          Literal(2)
+        ),
+        conditionalBodies = Seq(
+          TestCompoundBody(Seq(TestLeafStatement("body1"))),
+          TestCompoundBody(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = None,
+        session = spark,
+        context = MockScriptingContext()
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("Project", "Project", "body2"))
+  }
+
+  test("simple case - without else (unsuccessful checks)") {
+    val iter = TestCompoundBody(Seq(
+      new SimpleCaseStatementExec(
+        caseVariableExec = TestIntegerProjection(3, "3"),
+        conditionExpressions = Seq(
+          Literal(1),
+          Literal(2)
+        ),
+        conditionalBodies = Seq(
+          TestCompoundBody(Seq(TestLeafStatement("body1"))),
+          TestCompoundBody(Seq(TestLeafStatement("body2")))
+        ),
+        elseBody = None,
+        session = spark,
+        context = MockScriptingContext()
+      )
+    )).getTreeIterator
+    val statements = iter.map(extractStatementValue).toSeq
+    assert(statements === Seq("Project", "Project"))
+  }
+
   test("loop statement with leave") {
     val iter = TestCompoundBody(
       statements = Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 30efac0737dd..5a5f4185c50c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -876,35 +876,164 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
           "expression" -> "'one'",
           "sourceType" -> "\"STRING\"",
           "targetType" -> "\"BIGINT\""),
-        context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27))
+        context = ExpectedContext(fragment = "", start = -1, stop = -1))
     }
 
     withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+      val e = intercept[SqlScriptingException](
+        runSqlScript(commands)
+      )
       checkError(
-        exception = intercept[SqlScriptingException](
-          runSqlScript(commands)
-        ),
+        exception = e,
         condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
-        parameters = Map("invalidStatement" -> "\"ONE\""))
+        parameters = Map("invalidStatement" -> "(1 = ONE)"))
+      assert(e.origin.line.contains(3))
     }
   }
 
-  test("simple case compare with null") {
+  test("simple case with empty query result") {
     withTable("t") {
       val commands =
        """
          |BEGIN
-          | CREATE TABLE t (a INT) USING parquet;
-          | CASE (SELECT COUNT(*) FROM t)
-          |   WHEN 1 THEN
-          |     SELECT 42;
-          |   ELSE
-          |     SELECT 43;
-          | END CASE;
+          |CREATE TABLE t (a INT) USING parquet;
+          |CASE (SELECT * FROM t)
+          |  WHEN 1 THEN
+          |    SELECT 41;
+          |  WHEN 2 THEN
+          |    SELECT 42;
+          |  ELSE
+          |    SELECT 43;
+          |  END CASE;
          |END
          |""".stripMargin
 
-      val expected = Seq(Seq.empty[Row], Seq(Row(43)))
-      verifySqlScriptResult(commands, expected)
+      val e = intercept[SqlScriptingException] {
+        verifySqlScriptResult(commands, Seq.empty)
+      }
+      checkError(
+        exception = e,
+        sqlState = "21000",
+        condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+        parameters = Map("invalidStatement" -> "(NULL = 1)")
+      )
+      assert(e.origin.line.contains(4))
     }
   }
 
+  test("simple case with null comparison") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |CASE 1
+          |  WHEN NULL THEN
+          |    SELECT 41;
+          |  WHEN 2 THEN
+          |    SELECT 42;
+          |  ELSE
+          |    SELECT 43;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      val e = intercept[SqlScriptingException] {
+        verifySqlScriptResult(commands, Seq.empty)
+      }
+      checkError(
+        exception = e,
+        sqlState = "21000",
+        condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+        parameters = Map("invalidStatement" -> "(1 = NULL)")
+      )
+      assert(e.origin.line.contains(3))
+    }
+  }
+
+  test("simple case with null comparison 2") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |CASE NULL
+          |  WHEN 1 THEN
+          |    SELECT 41;
+          |  WHEN 2 THEN
+          |    SELECT 42;
+          |  ELSE
+          |    SELECT 43;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      val e = intercept[SqlScriptingException] {
+        verifySqlScriptResult(commands, Seq.empty)
+      }
+      checkError(
+        exception = e,
+        sqlState = "21000",
+        condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+        parameters = Map("invalidStatement" -> "(NULL = 1)")
+      )
+      assert(e.origin.line.contains(3))
+    }
+  }
+
+  test("simple case with multiple columns scalar subquery") {
+    val commands =
+      """
+        |BEGIN
+        |CASE (SELECT 1, 2)
+        |  WHEN 1 THEN
+        |    SELECT 41;
+        |  WHEN 2 THEN
+        |    SELECT 42;
+        |  ELSE
+        |    SELECT 43;
+        |  END CASE;
+        |END
+        |""".stripMargin
+
+    val e = intercept[AnalysisException] {
+      verifySqlScriptResult(commands, Seq.empty)
+    }
+    checkError(
+      exception = e,
+      sqlState = "42823",
+      condition = "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
+      parameters = Map("number" -> "2"),
+      context = ExpectedContext(fragment = "(SELECT 1, 2)", start = 12, stop = 24)
+    )
+  }
+
+  test("simple case with multiple rows scalar subquery") {
+    withTable("t") {
+      val commands =
+        """
+          |BEGIN
+          |CREATE TABLE t (a INT) USING parquet;
+          |INSERT INTO t VALUES (1);
+          |INSERT INTO t VALUES (1);
+          |CASE (SELECT * FROM t)
+          |  WHEN 1 THEN
+          |    SELECT 41;
+          |  WHEN 2 THEN
+          |    SELECT 42;
+          |  ELSE
+          |    SELECT 43;
+          |  END CASE;
+          |END
+          |""".stripMargin
+
+      val e = intercept[SparkException] {
+        verifySqlScriptResult(commands, Seq.empty)
+      }
+      checkError(
+        exception = e,
+        sqlState = "21000",
+        condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
+        parameters = Map.empty,
+        context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 102, stop = 118)
+      )
     }
   }
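Note on the runtime lowering: the sketch below is illustrative only (it is not part of the patch; it assumes spark-catalyst on the classpath, and the object name and literal values are invented). It shows how SimpleCaseStatementExec turns one WHEN branch into the same shape a searched CASE evaluates, namely an EqualTo comparison against the cached case-variable literal projected over OneRowRelation, and how the condition text that feeds the error-message Origin is derived:

{{{
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}

object SimpleCaseLoweringSketch {
  def main(args: Array[String]): Unit = {
    // Cached result of evaluating the case variable once (cf. validateCache()).
    val caseVariableLiteral = Literal(1)
    // One WHEN clause expression from the script (hypothetical value).
    val whenExpression = Literal(5)
    // Per-branch condition, mirroring createConditionBodyIterator: a one-row
    // projection of the equality comparison.
    val condition = Project(
      Seq(Alias(EqualTo(caseVariableLiteral, whenExpression), "condition")()),
      OneRowRelation())
    // The Alias child's string form is what the patch stores in Origin.sqlText,
    // so a failing branch can be reported with its comparison text.
    val conditionText = condition.projectList.head.asInstanceOf[Alias].child.toString
    println(conditionText) // prints: (1 = 5)
  }
}
}}}

Evaluating that projection yields a single boolean row, which evaluateBooleanCondition consumes exactly as it does for the searched variant, so both CASE flavors share the same downstream machinery.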