diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index cd7af021d8ff..523b7c88fc8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -445,7 +445,7 @@ class AstBuilder extends DataTypeAstBuilder
private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
- labelCtx: SqlScriptingLabelContext): CaseStatement = {
+ labelCtx: SqlScriptingLabelContext): SearchedCaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
@@ -464,7 +464,7 @@ class AstBuilder extends DataTypeAstBuilder
s" ${conditionalBodies.length} in case statement")
}
- CaseStatement(
+ SearchedCaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(
@@ -475,30 +475,31 @@ class AstBuilder extends DataTypeAstBuilder
private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
- labelCtx: SqlScriptingLabelContext): CaseStatement = {
- // uses EqualTo to compare the case variable(the main case expression)
- // to the WHEN clause expressions
- val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
- SingleStatement(
- Project(
- Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()),
- OneRowRelation()))
- })
+ labelCtx: SqlScriptingLabelContext): SimpleCaseStatement = {
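+ // Unlike the searched variant, the case variable and WHEN expressions are kept
+ // as raw expressions here; the `caseVariable = whenExpression` comparison is
+ // built at execution time.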
+ val caseVariableExpr = withOrigin(ctx.caseVariable) {
+ expression(ctx.caseVariable)
+ }
+ val conditionExpressions =
+ ctx.conditionExpressions.asScala.toList
+ .map(exprCtx => withOrigin(exprCtx) {
+ expression(exprCtx)
+ })
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
)
- if (conditions.length != conditionalBodies.length) {
+ if (conditionExpressions.length != conditionalBodies.length) {
throw SparkException.internalError(
- s"Mismatched number of conditions ${conditions.length} and condition bodies" +
+ s"Mismatched number of conditions ${conditionExpressions.length} and condition bodies" +
s" ${conditionalBodies.length} in case statement")
}
- CaseStatement(
- conditions = conditions,
- conditionalBodies = conditionalBodies,
+ SimpleCaseStatement(
+ caseVariableExpr,
+ conditionExpressions,
+ conditionalBodies,
elseBody = Option(ctx.elseBody).map(
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
index bbbdd3b09a3c..2073b296a0dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/SqlScriptingLogicalPlans.scala
@@ -21,7 +21,7 @@ import java.util.Locale
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -220,13 +220,24 @@ case class IterateStatement(label: String) extends CompoundPlanStatement {
}
/**
- * Logical operator for CASE statement.
+ * Logical operator for CASE statement, SEARCHED variant.
+ * Example:
+ * {{{
+ * CASE
+ * WHEN x = 1 THEN
+ * SELECT 1;
+ * WHEN x = 2 THEN
+ * SELECT 2;
+ * ELSE
+ * SELECT 3;
+ * END CASE;
+ * }}}
* @param conditions Collection of conditions which correspond to WHEN clauses.
* @param conditionalBodies Collection of bodies that have a corresponding condition,
* in WHEN branches.
* @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
*/
-case class CaseStatement(
+case class SearchedCaseStatement(
conditions: Seq[SingleStatement],
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
@@ -253,7 +264,44 @@ case class CaseStatement(
conditionalBodies = conditionalBodies.dropRight(1)
elseBody = Some(conditionalBodies.last)
}
- CaseStatement(conditions, conditionalBodies, elseBody)
+ SearchedCaseStatement(conditions, conditionalBodies, elseBody)
+ }
+}
+
+/**
+ * Logical operator for CASE statement, SIMPLE variant.
+ * Example:
+ * {{{
+ * CASE x
+ * WHEN 1 THEN
+ * SELECT 1;
+ * WHEN 2 THEN
+ * SELECT 2;
+ * ELSE
+ * SELECT 3;
+ * END CASE;
+ * }}}
+ * @param caseVariableExpression Expression to which all conditionExpressions are compared.
+ * @param conditionExpressions Collection of expressions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of bodies that have a corresponding condition,
+ * in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
+ */
+case class SimpleCaseStatement(
+ caseVariableExpression: Expression,
+ conditionExpressions: Seq[Expression],
+ conditionalBodies: Seq[CompoundBody],
+ elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
+ assert(conditionExpressions.length == conditionalBodies.length)
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override def children: Seq[LogicalPlan] = conditionalBodies
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
+ val conditionalBodies = newChildren.map(_.asInstanceOf[CompoundBody])
+ SimpleCaseStatement(caseVariableExpression, conditionExpressions, conditionalBodies, elseBody)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 9de5d09feb76..99658e1ff294 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery}
+import org.apache.spark.sql.catalyst.expressions.{In, Literal, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.SQLHelper
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SetVariable, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SearchedCaseStatement, SetVariable, SimpleCaseStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
@@ -1462,8 +1462,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditions.head.getText == "1 = 1")
@@ -1502,9 +1502,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
+ assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.conditions.length == 3)
assert(caseStmt.conditionalBodies.length == 3)
assert(caseStmt.elseBody.isEmpty)
@@ -1545,8 +1545,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.elseBody.isDefined)
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
@@ -1574,9 +1574,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
+ assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditionalBodies.length == 1)
assert(caseStmt.elseBody.isEmpty)
@@ -1584,9 +1584,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditions.head.getText == "1 = 1")
- assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+ assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SearchedCaseStatement])
val nestedCaseStmt =
- caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+ caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SearchedCaseStatement]
assert(nestedCaseStmt.conditions.length == 1)
assert(nestedCaseStmt.conditionalBodies.length == 1)
@@ -1616,11 +1616,15 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
- assert(caseStmt.conditions.length == 1)
- assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+ assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+ assert(caseStmt.caseVariableExpression == Literal(1))
+ assert(caseStmt.conditionExpressions.length == 1)
+ assert(caseStmt.conditionExpressions.head == Literal(1))
+
+ assert(caseStmt.conditionalBodies.length == 1)
+ assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 1")
}
test("simple case statement with empty body") {
@@ -1656,31 +1660,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
+ assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
- assert(caseStmt.conditions.length == 3)
+ val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+ assert(caseStmt.caseVariableExpression == Literal(1))
+ assert(caseStmt.conditionExpressions.length == 3)
assert(caseStmt.conditionalBodies.length == 3)
assert(caseStmt.elseBody.isEmpty)
- assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+ assert(caseStmt.conditionExpressions.head == Literal(1))
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 1")
- assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(
- caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery])
+ assert(caseStmt.conditionExpressions(1).isInstanceOf[ScalarSubquery])
assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT * FROM b")
- assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(
- caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In])
+ assert(caseStmt.conditionExpressions(2).isInstanceOf[In])
assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
@@ -1701,12 +1701,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
+ assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
+ val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+
+ assert(caseStmt.caseVariableExpression == Literal(1))
assert(caseStmt.elseBody.isDefined)
- assert(caseStmt.conditions.length == 1)
- assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
+ assert(caseStmt.conditionExpressions.length == 1)
+ assert(caseStmt.conditionExpressions.head == Literal(1))
+
+ assert(caseStmt.conditionalBodies.length == 1)
+ assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
+ .getText == "SELECT 42")
assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
@@ -1730,28 +1735,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
- assert(tree.collection.head.isInstanceOf[CaseStatement])
+ assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
- val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
- assert(caseStmt.conditions.length == 1)
+ val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
+
+ assert(caseStmt.caseVariableExpression.isInstanceOf[ScalarSubquery])
+ assert(caseStmt.conditionExpressions.length == 1)
assert(caseStmt.conditionalBodies.length == 1)
assert(caseStmt.elseBody.isEmpty)
- assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(
- caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == Literal(1))
+ assert(caseStmt.conditionExpressions.head == Literal(1))
- assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
+ assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SimpleCaseStatement])
val nestedCaseStmt =
- caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
+ caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SimpleCaseStatement]
- assert(nestedCaseStmt.conditions.length == 1)
+ assert(nestedCaseStmt.caseVariableExpression == Literal(2))
+ assert(nestedCaseStmt.conditionExpressions.length == 1)
assert(nestedCaseStmt.conditionalBodies.length == 1)
assert(nestedCaseStmt.elseBody.isDefined)
- assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
- checkSimpleCaseStatementCondition(
- nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2))
+ assert(nestedCaseStmt.conditionExpressions.head == Literal(2))
assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
@@ -2910,17 +2914,4 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
.replace("END", "")
.trim
}
-
- private def checkSimpleCaseStatementCondition(
- conditionStatement: SingleStatement,
- predicateLeft: Expression => Boolean,
- predicateRight: Expression => Boolean): Unit = {
- assert(conditionStatement.parsedPlan.isInstanceOf[Project])
- val project = conditionStatement.parsedPlan.asInstanceOf[Project]
- assert(project.projectList.head.isInstanceOf[Alias])
- assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo])
- val equalTo = project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo]
- assert(predicateLeft(equalTo.left))
- assert(predicateRight(equalTo.right))
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index ce0876e8f629..62edf3e46477 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.{ExecuteImmediateQuery, NameParameterizedQuery, UnresolvedAttribute, UnresolvedIdentifier}
-import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, CreateMap, CreateNamedStruct, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DefaultValueExpression, LogicalPlan, OneRowRelation, Project, SetVariable}
import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
@@ -528,14 +528,14 @@ class WhileStatementExec(
}
/**
- * Executable node for CaseStatement.
+ * Executable node for SearchedCaseStatement.
* @param conditions Collection of executable conditions which correspond to WHEN clauses.
* @param conditionalBodies Collection of executable bodies that have a corresponding condition,
* in WHEN branches.
* @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
* @param session Spark session that SQL script is executed within.
*/
-class CaseStatementExec(
+class SearchedCaseStatementExec(
conditions: Seq[SingleStatementExec],
conditionalBodies: Seq[CompoundBodyExec],
elseBody: Option[CompoundBodyExec],
@@ -599,6 +599,118 @@ class CaseStatementExec(
}
}
+/**
+ * Executable node for SimpleCaseStatement.
+ * @param caseVariableExec Statement producing the value each conditionExpression is compared against.
+ * @param conditionExpressions Collection of expressions which correspond to WHEN clauses.
+ * @param conditionalBodies Collection of executable bodies that have a corresponding condition,
+ * in WHEN branches.
+ * @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
+ * @param session Spark session that SQL script is executed within.
+ * @param context SqlScriptingExecutionContext keeps the execution state of current script.
+ */
+class SimpleCaseStatementExec(
+ caseVariableExec: SingleStatementExec,
+ conditionExpressions: Seq[Expression],
+ conditionalBodies: Seq[CompoundBodyExec],
+ elseBody: Option[CompoundBodyExec],
+ session: SparkSession,
+ context: SqlScriptingExecutionContext) extends NonLeafStatementExec {
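+ // Tracks whether the iterator is still evaluating WHEN conditions (Condition)
+ // or draining the body of the matched branch (Body).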
+ private object CaseState extends Enumeration {
+ val Condition, Body = Value
+ }
+
+ private var state = CaseState.Condition
+ var bodyExec: Option[CompoundBodyExec] = None
+
+ var conditionBodyTupleIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = _
+ private var caseVariableLiteral: Literal = _
+
+ private var isCacheValid = false
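+ // Evaluate the case variable query once and cache the result as a Literal so
+ // that every WHEN comparison reuses the same value.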
+ private def validateCache(): Unit = {
+ if (!isCacheValid) {
+ val values = caseVariableExec.buildDataFrame(session).collect()
+ caseVariableExec.isExecuted = true
+
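+ // The case variable is projected over OneRowRelation, so exactly one row comes
+ // back; a scalar subquery with no rows surfaces here as a NULL literal.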
+ caseVariableLiteral = Literal(values.head.get(0))
+ conditionBodyTupleIterator = createConditionBodyIterator
+ isCacheValid = true
+ }
+ }
+
+ private def cachedCaseVariableLiteral: Literal = {
+ validateCache()
+ caseVariableLiteral
+ }
+
+ private def cachedConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = {
+ validateCache()
+ conditionBodyTupleIterator
+ }
+
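+ // Emits each generated condition statement in turn; once a condition evaluates
+ // to true (or the ELSE branch is taken), switches to iterating the chosen body.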
+ private lazy val treeIterator: Iterator[CompoundStatementExec] =
+ new Iterator[CompoundStatementExec] {
+ override def hasNext: Boolean = state match {
+ case CaseState.Condition => cachedConditionBodyIterator.hasNext || elseBody.isDefined
+ case CaseState.Body => bodyExec.exists(_.getTreeIterator.hasNext)
+ }
+
+ override def next(): CompoundStatementExec = state match {
+ case CaseState.Condition =>
+ cachedConditionBodyIterator.nextOption()
+ .map { case (condStmt, body) =>
+ if (evaluateBooleanCondition(session, condStmt)) {
+ bodyExec = Some(body)
+ state = CaseState.Body
+ }
+ condStmt
+ }
+ .orElse(elseBody.map { body =>
+ bodyExec = Some(body)
+ state = CaseState.Body
+ next()
+ })
+ .get
+ case CaseState.Body => bodyExec.get.getTreeIterator.next()
+ }
+ }
+
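+ // Builds one internal `caseVariable = whenExpression` condition statement per
+ // WHEN clause, paired with the body of that branch.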
+ private def createConditionBodyIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] =
+ conditionExpressions.zip(conditionalBodies)
+ .iterator
+ .map { case (expr, body) =>
+ val condition = Project(
+ Seq(Alias(EqualTo(cachedCaseVariableLiteral, expr), "condition")()),
+ OneRowRelation()
+ )
+ // We override the Origin to produce more descriptive error messages. For
+ // example, if the case variable is 1 and the condition expression it is
+ // compared to is 5, the resulting Origin text is "(1 = 5)".
+ val conditionText = condition.projectList.head.asInstanceOf[Alias].child.toString
+ val condStmt = new SingleStatementExec(
+ condition,
+ Origin(sqlText = Some(conditionText),
+ startIndex = Some(0),
+ stopIndex = Some(conditionText.length - 1),
+ line = caseVariableExec.origin.line),
+ Map.empty,
+ isInternal = true,
+ context = context
+ )
+ (condStmt, body)
+ }
+
+ override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+
+ override def reset(): Unit = {
+ state = CaseState.Condition
+ isCacheValid = false
+ caseVariableExec.reset()
+ conditionalBodies.foreach(b => b.reset())
+ elseBody.foreach(b => b.reset())
+ }
+}
+
/**
* Executable node for RepeatStatement.
* @param condition Executable node for the condition - evaluates to a row with a single boolean
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
index 919d63bb0c71..4485e780c8bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.scripting
import scala.collection.mutable.HashMap
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, RepeatStatement, SingleStatement, WhileStatement}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CompoundPlanStatement, ExceptionHandlerType, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, OneRowRelation, Project, RepeatStatement, SearchedCaseStatement, SimpleCaseStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.errors.SqlScriptingErrors
@@ -187,7 +187,7 @@ case class SqlScriptingInterpreter(session: SparkSession) {
new IfElseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
- case CaseStatement(conditions, conditionalBodies, elseBody) =>
+ case SearchedCaseStatement(conditions, conditionalBodies, elseBody) =>
val conditionsExec = conditions.map(condition =>
new SingleStatementExec(
condition.parsedPlan,
@@ -199,9 +199,25 @@ case class SqlScriptingInterpreter(session: SparkSession) {
transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
val unconditionalBodiesExec = elseBody.map(body =>
transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
- new CaseStatementExec(
+ new SearchedCaseStatementExec(
conditionsExec, conditionalBodiesExec, unconditionalBodiesExec, session)
+ case SimpleCaseStatement(caseExpr, conditionExpressions, conditionalBodies, elseBody) =>
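+ // Wrap the case variable expression in a single-row Project so it can be
+ // evaluated once and compared against each WHEN expression.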
+ val caseValueStmt = SingleStatement(
+ Project(Seq(Alias(caseExpr, "caseVariable")()), OneRowRelation()))
+ val caseVarExec = new SingleStatementExec(
+ caseValueStmt.parsedPlan,
+ caseExpr.origin,
+ args,
+ isInternal = true,
+ context)
+ val conditionalBodiesExec = conditionalBodies.map(body =>
+ transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
+ val elseBodyExec = elseBody.map(body =>
+ transformTreeIntoExecutable(body, args, context).asInstanceOf[CompoundBodyExec])
+ new SimpleCaseStatementExec(
+ caseVarExec, conditionExpressions, conditionalBodiesExec, elseBodyExec, session, context)
+
case WhileStatement(condition, body, label) =>
val conditionExec =
new SingleStatementExec(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
index f83ae87290ec..2d732f6f453a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala
@@ -77,10 +77,28 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
null
)
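+ // Projects a single integer literal over OneRowRelation, standing in for the
+ // evaluated case variable in SimpleCaseStatementExec tests.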
+ case class TestIntegerProjection(value: Int, description: String)
+ extends SingleStatementExec(
+ parsedPlan = Project(Seq(Alias(Literal(value), description)()), OneRowRelation()),
+ Origin(startIndex = Some(0), stopIndex = Some(description.length)),
+ Map.empty,
+ isInternal = false,
+ null
+ )
+
case class DummyLogicalPlan() extends LeafNode {
override def output: Seq[Attribute] = Seq.empty
}
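+ // No-op execution context: scope enter/exit do nothing, which is sufficient
+ // for exercising the case-statement iterators in isolation.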
+ case class MockScriptingContext() extends SqlScriptingExecutionContext {
+ override def enterScope(
+ label: String,
+ triggerHandlerMap: TriggerToExceptionHandlerMap
+ ): Unit = ()
+
+ override def exitScope(label: String): Unit = ()
+ }
+
case class TestLoopCondition(
condVal: Boolean, reps: Int, description: String)
extends SingleStatementExec(
@@ -162,7 +180,9 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
=> "DropVariable"
case execImm: SingleStatementExec if execImm.parsedPlan.isInstanceOf[ExecuteImmediateQuery]
=> "ExecuteImmediate"
- case _ => fail("Unexpected statement type")
+ case project: SingleStatementExec if project.parsedPlan.isInstanceOf[Project]
+ => "Project"
+ case _ => fail(s"Unexpected statement: $statement")
}
// Tests
@@ -642,7 +662,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("searched case - enter first WHEN clause") {
val iter = TestCompoundBody(Seq(
- new CaseStatementExec(
+ new SearchedCaseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = true, description = "con1"),
TestIfElseCondition(condVal = false, description = "con2")
@@ -661,7 +681,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("searched case - enter body of the ELSE clause") {
val iter = TestCompoundBody(Seq(
- new CaseStatementExec(
+ new SearchedCaseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1")
),
@@ -678,7 +698,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("searched case - enter second WHEN clause") {
val iter = TestCompoundBody(Seq(
- new CaseStatementExec(
+ new SearchedCaseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = true, description = "con2")
@@ -697,7 +717,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("searched case - without else (successful check)") {
val iter = TestCompoundBody(Seq(
- new CaseStatementExec(
+ new SearchedCaseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = true, description = "con2")
@@ -716,7 +736,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
test("searched case - without else (unsuccessful checks)") {
val iter = TestCompoundBody(Seq(
- new CaseStatementExec(
+ new SearchedCaseStatementExec(
conditions = Seq(
TestIfElseCondition(condVal = false, description = "con1"),
TestIfElseCondition(condVal = false, description = "con2")
@@ -733,6 +753,109 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi
assert(statements === Seq("con1", "con2"))
}
+ test("simple case - enter first WHEN clause") {
+ val iter = TestCompoundBody(Seq(
+ new SimpleCaseStatementExec(
+ caseVariableExec = TestIntegerProjection(1, "1"),
+ conditionExpressions = Seq(
+ Literal(1),
+ Literal(2)
+ ),
+ conditionalBodies = Seq(
+ TestCompoundBody(Seq(TestLeafStatement("body1"))),
+ TestCompoundBody(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = Some(TestCompoundBody(Seq(TestLeafStatement("body3")))),
+ session = spark,
+ context = MockScriptingContext()
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("Project", "body1"))
+ }
+
+ test("simple case - enter body of the ELSE clause") {
+ val iter = TestCompoundBody(Seq(
+ new SimpleCaseStatementExec(
+ caseVariableExec = TestIntegerProjection(2, "2"),
+ conditionExpressions = Seq(
+ Literal(1)
+ ),
+ conditionalBodies = Seq(
+ TestCompoundBody(Seq(TestLeafStatement("body1")))
+ ),
+ elseBody = Some(TestCompoundBody(Seq(TestLeafStatement("body2")))),
+ session = spark,
+ context = MockScriptingContext()
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("Project", "body2"))
+ }
+
+ test("simple case - enter second WHEN clause") {
+ val iter = TestCompoundBody(Seq(
+ new SimpleCaseStatementExec(
+ caseVariableExec = TestIntegerProjection(2, "2"),
+ conditionExpressions = Seq(
+ Literal(1),
+ Literal(2)
+ ),
+ conditionalBodies = Seq(
+ TestCompoundBody(Seq(TestLeafStatement("body1"))),
+ TestCompoundBody(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = Some(TestCompoundBody(Seq(TestLeafStatement("body3")))),
+ session = spark,
+ context = MockScriptingContext()
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("Project", "Project", "body2"))
+ }
+
+ test("simple case - without else (successful check)") {
+ val iter = TestCompoundBody(Seq(
+ new SimpleCaseStatementExec(
+ caseVariableExec = TestIntegerProjection(2, "2"),
+ conditionExpressions = Seq(
+ Literal(1),
+ Literal(2)
+ ),
+ conditionalBodies = Seq(
+ TestCompoundBody(Seq(TestLeafStatement("body1"))),
+ TestCompoundBody(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = None,
+ session = spark,
+ context = MockScriptingContext()
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("Project", "Project", "body2"))
+ }
+
+ test("simple case - without else (unsuccessful checks)") {
+ val iter = TestCompoundBody(Seq(
+ new SimpleCaseStatementExec(
+ caseVariableExec = TestIntegerProjection(3, "3"),
+ conditionExpressions = Seq(
+ Literal(1),
+ Literal(2)
+ ),
+ conditionalBodies = Seq(
+ TestCompoundBody(Seq(TestLeafStatement("body1"))),
+ TestCompoundBody(Seq(TestLeafStatement("body2")))
+ ),
+ elseBody = None,
+ session = spark,
+ context = MockScriptingContext()
+ )
+ )).getTreeIterator
+ val statements = iter.map(extractStatementValue).toSeq
+ assert(statements === Seq("Project", "Project"))
+ }
+
test("loop statement with leave") {
val iter = TestCompoundBody(
statements = Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
index 30efac0737dd..5a5f4185c50c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala
@@ -876,35 +876,164 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession {
"expression" -> "'one'",
"sourceType" -> "\"STRING\"",
"targetType" -> "\"BIGINT\""),
- context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27))
+ context = ExpectedContext(fragment = "", start = -1, stop = -1))
}
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
+ val e = intercept[SqlScriptingException](
+ runSqlScript(commands)
+ )
checkError(
- exception = intercept[SqlScriptingException](
- runSqlScript(commands)
- ),
+ exception = e,
condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
- parameters = Map("invalidStatement" -> "\"ONE\""))
+ parameters = Map("invalidStatement" -> "(1 = ONE)"))
+ assert(e.origin.line.contains(3))
}
}
- test("simple case compare with null") {
+ test("simple case with empty query result") {
withTable("t") {
val commands =
"""
|BEGIN
- | CREATE TABLE t (a INT) USING parquet;
- | CASE (SELECT COUNT(*) FROM t)
- | WHEN 1 THEN
- | SELECT 42;
- | ELSE
- | SELECT 43;
- | END CASE;
+ |CREATE TABLE t (a INT) USING parquet;
+ |CASE (SELECT * FROM t)
+ | WHEN 1 THEN
+ | SELECT 41;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
|END
|""".stripMargin
- val expected = Seq(Seq.empty[Row], Seq(Row(43)))
- verifySqlScriptResult(commands, expected)
+ val e = intercept[SqlScriptingException] {
+ verifySqlScriptResult(commands, Seq.empty)
+ }
+ checkError(
+ exception = e,
+ sqlState = "21000",
+ condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+ parameters = Map("invalidStatement" -> "(NULL = 1)")
+ )
+ assert(e.origin.line.contains(4))
+ }
+ }
+
+ test("simple case with null comparison") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ |CASE 1
+ | WHEN NULL THEN
+ | SELECT 41;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val e = intercept[SqlScriptingException] {
+ verifySqlScriptResult(commands, Seq.empty)
+ }
+ checkError(
+ exception = e,
+ sqlState = "21000",
+ condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+ parameters = Map("invalidStatement" -> "(1 = NULL)")
+ )
+ assert(e.origin.line.contains(3))
+ }
+ }
+
+ test("simple case with null comparison 2") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ |CASE NULL
+ | WHEN 1 THEN
+ | SELECT 41;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val e = intercept[SqlScriptingException] {
+ verifySqlScriptResult(commands, Seq.empty)
+ }
+ checkError(
+ exception = e,
+ sqlState = "21000",
+ condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW",
+ parameters = Map("invalidStatement" -> "(NULL = 1)")
+ )
+ assert(e.origin.line.contains(3))
+ }
+ }
+
+ test("simple case with multiple columns scalar subquery") {
+ val commands =
+ """
+ |BEGIN
+ |CASE (SELECT 1, 2)
+ | WHEN 1 THEN
+ | SELECT 41;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val e = intercept[AnalysisException] {
+ verifySqlScriptResult(commands, Seq.empty)
+ }
+ checkError(
+ exception = e,
+ sqlState = "42823",
+ condition = "INVALID_SUBQUERY_EXPRESSION.SCALAR_SUBQUERY_RETURN_MORE_THAN_ONE_OUTPUT_COLUMN",
+ parameters = Map("number" -> "2"),
+ context = ExpectedContext(fragment = "(SELECT 1, 2)", start = 12, stop = 24)
+ )
+ }
+
+ test("simple case with multiple rows scalar subquery") {
+ withTable("t") {
+ val commands =
+ """
+ |BEGIN
+ |CREATE TABLE t (a INT) USING parquet;
+ |INSERT INTO t VALUES (1);
+ |INSERT INTO t VALUES (1);
+ |CASE (SELECT * FROM t)
+ | WHEN 1 THEN
+ | SELECT 41;
+ | WHEN 2 THEN
+ | SELECT 42;
+ | ELSE
+ | SELECT 43;
+ | END CASE;
+ |END
+ |""".stripMargin
+
+ val e = intercept[SparkException] {
+ verifySqlScriptResult(commands, Seq.empty)
+ }
+ checkError(
+ exception = e,
+ sqlState = "21000",
+ condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS",
+ parameters = Map.empty,
+ context = ExpectedContext(fragment = "(SELECT * FROM t)", start = 102, stop = 118)
+ )
}
}