Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ class AstBuilder extends DataTypeAstBuilder

private def visitSearchedCaseStatementImpl(
ctx: SearchedCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
labelCtx: SqlScriptingLabelContext): SearchedCaseStatement = {
val conditions = ctx.conditions.asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Expand All @@ -464,7 +464,7 @@ class AstBuilder extends DataTypeAstBuilder
s" ${conditionalBodies.length} in case statement")
}

CaseStatement(
SearchedCaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
elseBody = Option(ctx.elseBody).map(
Expand All @@ -475,30 +475,31 @@ class AstBuilder extends DataTypeAstBuilder

private def visitSimpleCaseStatementImpl(
ctx: SimpleCaseStatementContext,
labelCtx: SqlScriptingLabelContext): CaseStatement = {
// uses EqualTo to compare the case variable(the main case expression)
// to the WHEN clause expressions
val conditions = ctx.conditionExpressions.asScala.toList.map(expr => withOrigin(expr) {
SingleStatement(
Project(
Seq(Alias(EqualTo(expression(ctx.caseVariable), expression(expr)), "condition")()),
OneRowRelation()))
})
labelCtx: SqlScriptingLabelContext): SimpleCaseStatement = {
val caseVariableExpr = withOrigin(ctx.caseVariable) {
expression(ctx.caseVariable)
}
val conditionExpressions =
ctx.conditionExpressions.asScala.toList
.map(exprCtx => withOrigin(exprCtx) {
expression(exprCtx)
})
val conditionalBodies =
ctx.conditionalBodies.asScala.toList.map(
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
)

if (conditions.length != conditionalBodies.length) {
if (conditionExpressions.length != conditionalBodies.length) {
throw SparkException.internalError(
s"Mismatched number of conditions ${conditions.length} and condition bodies" +
s"Mismatched number of conditions ${conditionExpressions.length} and condition bodies" +
s" ${conditionalBodies.length} in case statement")
}

CaseStatement(
conditions = conditions,
conditionalBodies = conditionalBodies,
SimpleCaseStatement(
caseVariableExpr,
conditionExpressions,
conditionalBodies,
elseBody = Option(ctx.elseBody).map(
body =>
visitCompoundBodyImpl(body, None, allowVarDeclare = false, labelCtx, isScope = false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.ExceptionHandlerType.ExceptionHandlerType
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.errors.SqlScriptingErrors
Expand Down Expand Up @@ -220,13 +220,24 @@ case class IterateStatement(label: String) extends CompoundPlanStatement {
}

/**
* Logical operator for CASE statement.
* Logical operator for CASE statement, SEARCHED variant.<br>
* Example:
* {{{
* CASE
* WHEN x = 1 THEN
* SELECT 1;
* WHEN x = 2 THEN
* SELECT 2;
* ELSE
* SELECT 3;
* END CASE;
* }}}
* @param conditions Collection of conditions which correspond to WHEN clauses.
* @param conditionalBodies Collection of bodies that have a corresponding condition,
* in WHEN branches.
* @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
*/
case class CaseStatement(
case class SearchedCaseStatement(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: update the description comment, to be equivalent of "SIMPLE variant" as below. Also, we might include a small examples to illustrate which is which, I always get confused...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

conditions: Seq[SingleStatement],
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
Expand All @@ -253,7 +264,44 @@ case class CaseStatement(
conditionalBodies = conditionalBodies.dropRight(1)
elseBody = Some(conditionalBodies.last)
}
CaseStatement(conditions, conditionalBodies, elseBody)
SearchedCaseStatement(conditions, conditionalBodies, elseBody)
}
}

/**
* Logical operator for CASE statement, SIMPLE variant.<br>
* Example:
* {{{
* CASE x
* WHEN 1 THEN
* SELECT 1;
* WHEN 2 THEN
* SELECT 2;
* ELSE
* SELECT 3;
* END CASE;
* }}}
* @param caseVariableExpression Expression with which all conditionExpressions will be compared to.
* @param conditionExpressions Collection of expressions which correspond to WHEN clauses.
* @param conditionalBodies Collection of bodies that have a corresponding condition,
* in WHEN branches.
* @param elseBody Body that is executed if none of the conditions are met, i.e. ELSE branch.
*/
case class SimpleCaseStatement(
caseVariableExpression: Expression,
conditionExpressions: Seq[Expression],
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
assert(conditionExpressions.length == conditionalBodies.length)

override def output: Seq[Attribute] = Seq.empty

override def children: Seq[LogicalPlan] = conditionalBodies

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = {
val conditionalBodies = newChildren.map(_.asInstanceOf[CompoundBody])
SimpleCaseStatement(caseVariableExpression, conditionExpressions, conditionalBodies, elseBody)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.parser

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Expression, In, Literal, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{In, Literal, ScalarSubquery}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.logical.{CaseStatement, CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SetVariable, SingleStatement, WhileStatement}
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, CreateVariable, ExceptionHandler, ForStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, Project, RepeatStatement, SearchedCaseStatement, SetVariable, SimpleCaseStatement, SingleStatement, WhileStatement}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.exceptions.SqlScriptingException
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -1462,8 +1462,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditions.head.getText == "1 = 1")
Expand Down Expand Up @@ -1502,9 +1502,9 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]

assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])

val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.conditions.length == 3)
assert(caseStmt.conditionalBodies.length == 3)
assert(caseStmt.elseBody.isEmpty)
Expand Down Expand Up @@ -1545,8 +1545,8 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.elseBody.isDefined)
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
Expand Down Expand Up @@ -1574,19 +1574,19 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
assert(tree.collection.head.isInstanceOf[SearchedCaseStatement])

val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
val caseStmt = tree.collection.head.asInstanceOf[SearchedCaseStatement]
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditionalBodies.length == 1)
assert(caseStmt.elseBody.isEmpty)

assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditions.head.getText == "1 = 1")

assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SearchedCaseStatement])
val nestedCaseStmt =
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SearchedCaseStatement]

assert(nestedCaseStmt.conditions.length == 1)
assert(nestedCaseStmt.conditionalBodies.length == 1)
Expand Down Expand Up @@ -1616,11 +1616,15 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
assert(caseStmt.caseVariableExpression == Literal(1))
assert(caseStmt.conditionExpressions.length == 1)
assert(caseStmt.conditionExpressions.head == Literal(1))

assert(caseStmt.conditionalBodies.length == 1)
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 1")
}

test("simple case statement with empty body") {
Expand Down Expand Up @@ -1656,31 +1660,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]

assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])

val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
assert(caseStmt.conditions.length == 3)
val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]
assert(caseStmt.caseVariableExpression == Literal(1))
assert(caseStmt.conditionExpressions.length == 3)
assert(caseStmt.conditionalBodies.length == 3)
assert(caseStmt.elseBody.isEmpty)

assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
assert(caseStmt.conditionExpressions.head == Literal(1))

assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 1")

assert(caseStmt.conditions(1).isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(
caseStmt.conditions(1), _ == Literal(1), _.isInstanceOf[ScalarSubquery])
assert(caseStmt.conditionExpressions(1).isInstanceOf[ScalarSubquery])

assert(caseStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT * FROM b")

assert(caseStmt.conditions(2).isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(
caseStmt.conditions(2), _ == Literal(1), _.isInstanceOf[In])
assert(caseStmt.conditionExpressions(2).isInstanceOf[In])

assert(caseStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
Expand All @@ -1701,12 +1701,17 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])
val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]

assert(caseStmt.caseVariableExpression == Literal(1))
assert(caseStmt.elseBody.isDefined)
assert(caseStmt.conditions.length == 1)
assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(caseStmt.conditions.head, _ == Literal(1), _ == Literal(1))
assert(caseStmt.conditionExpressions.length == 1)
assert(caseStmt.conditionExpressions.head == Literal(1))

assert(caseStmt.conditionalBodies.length == 1)
assert(caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 42")

assert(caseStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
assert(caseStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
Expand All @@ -1730,28 +1735,27 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|""".stripMargin
val tree = parsePlan(sqlScriptText).asInstanceOf[CompoundBody]
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[CaseStatement])
assert(tree.collection.head.isInstanceOf[SimpleCaseStatement])

val caseStmt = tree.collection.head.asInstanceOf[CaseStatement]
assert(caseStmt.conditions.length == 1)
val caseStmt = tree.collection.head.asInstanceOf[SimpleCaseStatement]

assert(caseStmt.caseVariableExpression.isInstanceOf[ScalarSubquery])
assert(caseStmt.conditionExpressions.length == 1)
assert(caseStmt.conditionalBodies.length == 1)
assert(caseStmt.elseBody.isEmpty)

assert(caseStmt.conditions.head.isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(
caseStmt.conditions.head, _.isInstanceOf[ScalarSubquery], _ == Literal(1))
assert(caseStmt.conditionExpressions.head == Literal(1))

assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[CaseStatement])
assert(caseStmt.conditionalBodies.head.collection.head.isInstanceOf[SimpleCaseStatement])
val nestedCaseStmt =
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[CaseStatement]
caseStmt.conditionalBodies.head.collection.head.asInstanceOf[SimpleCaseStatement]

assert(nestedCaseStmt.conditions.length == 1)
assert(nestedCaseStmt.caseVariableExpression == Literal(2))
assert(nestedCaseStmt.conditionExpressions.length == 1)
assert(nestedCaseStmt.conditionalBodies.length == 1)
assert(nestedCaseStmt.elseBody.isDefined)

assert(nestedCaseStmt.conditions.head.isInstanceOf[SingleStatement])
checkSimpleCaseStatementCondition(
nestedCaseStmt.conditions.head, _ == Literal(2), _ == Literal(2))
assert(nestedCaseStmt.conditionExpressions.head == Literal(2))

assert(nestedCaseStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(nestedCaseStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
Expand Down Expand Up @@ -2910,17 +2914,4 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
.replace("END", "")
.trim
}

private def checkSimpleCaseStatementCondition(
conditionStatement: SingleStatement,
predicateLeft: Expression => Boolean,
predicateRight: Expression => Boolean): Unit = {
assert(conditionStatement.parsedPlan.isInstanceOf[Project])
val project = conditionStatement.parsedPlan.asInstanceOf[Project]
assert(project.projectList.head.isInstanceOf[Alias])
assert(project.projectList.head.asInstanceOf[Alias].child.isInstanceOf[EqualTo])
val equalTo = project.projectList.head.asInstanceOf[Alias].child.asInstanceOf[EqualTo]
assert(predicateLeft(equalTo.left))
assert(predicateRight(equalTo.right))
}
}
Loading