Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ compoundStatement
: statement
| setStatementWithOptionalVarKeyword
| beginEndCompoundBlock
| ifElseStatement
;

setStatementWithOptionalVarKeyword
Expand All @@ -71,6 +72,12 @@ setStatementWithOptionalVarKeyword
LEFT_PAREN query RIGHT_PAREN #setVariableWithOptionalKeyword
;

ifElseStatement
: IF booleanExpression THEN conditionalBodies+=compoundBody
(ELSE IF booleanExpression THEN conditionalBodies+=compoundBody)*
(ELSE elseBody=compoundBody)? END IF
;

singleStatement
: (statement|setResetStatement) SEMICOLON* EOF
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,23 @@ class AstBuilder extends DataTypeAstBuilder
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}.getOrElse {
visit(ctx.beginEndCompoundBlock()).asInstanceOf[CompoundPlanStatement]
visitChildren(ctx).asInstanceOf[CompoundPlanStatement]
}
}

override def visitIfElseStatement(ctx: IfElseStatementContext): IfElseStatement = {
IfElseStatement(
conditions = ctx.booleanExpression().asScala.toList.map(boolExpr => withOrigin(boolExpr) {
SingleStatement(
Project(
Seq(Alias(expression(boolExpr), "condition")()),
OneRowRelation()))
}),
conditionalBodies = ctx.conditionalBodies.asScala.toList.map(body => visitCompoundBody(body)),
elseBody = Option(ctx.elseBody).map(body => visitCompoundBody(body))
)
}

override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,19 @@ case class SingleStatement(parsedPlan: LogicalPlan)
case class CompoundBody(
collection: Seq[CompoundPlanStatement],
label: Option[String]) extends CompoundPlanStatement

/**
* Logical operator for IF ELSE statement.
* @param conditions Collection of conditions. First condition corresponds to IF clause,
* while others (if any) correspond to following ELSE IF clauses.
* @param conditionalBodies Collection of bodies that have a corresponding condition,
* in IF or ELSE IF branches.
* @param elseBody Body that is executed if none of the conditions are met,
* i.e. ELSE branch.
*/
case class IfElseStatement(
conditions: Seq[SingleStatement],
conditionalBodies: Seq[CompoundBody],
elseBody: Option[CompoundBody]) extends CompoundPlanStatement {
assert(conditions.length == conditionalBodies.length)
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,184 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
assert(e.getMessage.contains("Syntax error"))
}

test("if") {
val sqlScriptText =
"""
|BEGIN
| IF 1=1 THEN
| SELECT 42;
| END IF;
|END
|""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[IfElseStatement])
val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
assert(ifStmt.conditions.length == 1)
assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditions.head.getText == "1=1")
}

test("if else") {
val sqlScriptText =
"""BEGIN
|IF 1 = 1 THEN
| SELECT 1;
|ELSE
| SELECT 2;
|END IF;
|END
""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[IfElseStatement])

val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
assert(ifStmt.conditions.length == 1)
assert(ifStmt.conditionalBodies.length == 1)
assert(ifStmt.elseBody.isDefined)

assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditions.head.getText == "1 = 1")

assert(ifStmt.conditionalBodies.head.collection.length == 1)
assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 1")

assert(ifStmt.elseBody.get.collection.length == 1)
assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 2")
}

test("if else if") {
val sqlScriptText =
"""BEGIN
|IF 1 = 1 THEN
| SELECT 1;
|ELSE IF 2 = 2 THEN
| SELECT 2;
|ELSE
| SELECT 3;
|END IF;
|END
""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[IfElseStatement])

val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
assert(ifStmt.conditions.length == 2)
assert(ifStmt.conditionalBodies.length == 2)
assert(ifStmt.elseBody.isDefined)

assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditions.head.getText == "1 = 1")

assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 1")

assert(ifStmt.conditions(1).isInstanceOf[SingleStatement])
assert(ifStmt.conditions(1).getText == "2 = 2")

assert(ifStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 2")

assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 3")
}

test("if multi else if") {
val sqlScriptText =
"""BEGIN
|IF 1 = 1 THEN
| SELECT 1;
|ELSE IF 2 = 2 THEN
| SELECT 2;
|ELSE IF 3 = 3 THEN
| SELECT 3;
|END IF;
|END
""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[IfElseStatement])

val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
assert(ifStmt.conditions.length == 3)
assert(ifStmt.conditionalBodies.length == 3)
assert(ifStmt.elseBody.isEmpty)

assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditions.head.getText == "1 = 1")

assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 1")

assert(ifStmt.conditions(1).isInstanceOf[SingleStatement])
assert(ifStmt.conditions(1).getText == "2 = 2")

assert(ifStmt.conditionalBodies(1).collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditionalBodies(1).collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 2")

assert(ifStmt.conditions(2).isInstanceOf[SingleStatement])
assert(ifStmt.conditions(2).getText == "3 = 3")

assert(ifStmt.conditionalBodies(2).collection.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditionalBodies(2).collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 3")
}

test("if nested") {
val sqlScriptText =
"""
|BEGIN
| IF 1=1 THEN
| IF 2=1 THEN
| SELECT 41;
| ELSE
| SELECT 42;
| END IF;
| END IF;
|END
|""".stripMargin
val tree = parseScript(sqlScriptText)
assert(tree.collection.length == 1)
assert(tree.collection.head.isInstanceOf[IfElseStatement])

val ifStmt = tree.collection.head.asInstanceOf[IfElseStatement]
assert(ifStmt.conditions.length == 1)
assert(ifStmt.conditionalBodies.length == 1)
assert(ifStmt.elseBody.isEmpty)

assert(ifStmt.conditions.head.isInstanceOf[SingleStatement])
assert(ifStmt.conditions.head.getText == "1=1")

assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[IfElseStatement])
val nestedIfStmt = ifStmt.conditionalBodies.head.collection.head.asInstanceOf[IfElseStatement]

assert(nestedIfStmt.conditions.length == 1)
assert(nestedIfStmt.conditionalBodies.length == 1)
assert(nestedIfStmt.elseBody.isDefined)

assert(nestedIfStmt.conditions.head.isInstanceOf[SingleStatement])
assert(nestedIfStmt.conditions.head.getText == "2=1")

assert(nestedIfStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement])
assert(nestedIfStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 41")

assert(nestedIfStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement])
assert(nestedIfStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement]
.getText == "SELECT 42")
}

// Helper methods
def cleanupStatementString(statementStr: String): String = {
statementStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package org.apache.spark.sql.scripting

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{Origin, WithOrigin}
import org.apache.spark.sql.types.BooleanType

/**
* Trait for all SQL scripting execution nodes used during interpretation phase.
Expand Down Expand Up @@ -55,6 +57,33 @@ trait NonLeafStatementExec extends CompoundStatementExec {
* Tree iterator.
*/
def getTreeIterator: Iterator[CompoundStatementExec]

/**
* Evaluate the boolean condition represented by the statement.
* @param session SparkSession that SQL script is executed within.
* @param statement Statement representing the boolean condition to evaluate.
* @return Whether the condition evaluates to True.
*/
protected def evaluateBooleanCondition(
session: SparkSession,
statement: LeafStatementExec): Boolean = statement match {
case statement: SingleStatementExec =>
assert(!statement.isExecuted)
statement.isExecuted = true

// DataFrame evaluates to True if it is single row, single column
// of boolean type with value True.
val df = Dataset.ofRows(session, statement.parsedPlan)
df.schema.fields match {
case Array(field) if field.dataType == BooleanType =>
df.limit(2).collect() match {
case Array(row) => row.getBoolean(0)
case _ => false
}
case _ => false
}
case _ => false
}
}

/**
Expand Down Expand Up @@ -155,3 +184,79 @@ abstract class CompoundNestedStatementIteratorExec(collection: Seq[CompoundState
*/
class CompoundBodyExec(statements: Seq[CompoundStatementExec])
extends CompoundNestedStatementIteratorExec(statements)

/**
* Executable node for IfElseStatement.
* @param conditions Collection of executable conditions. First condition corresponds to IF clause,
* while others (if any) correspond to following ELSE IF clauses.
* @param conditionalBodies Collection of executable bodies that have a corresponding condition,
* in IF or ELSE IF branches.
* @param elseBody Body that is executed if none of the conditions are met,
* i.e. ELSE branch.
* @param session Spark session that SQL script is executed within.
*/
class IfElseStatementExec(
conditions: Seq[SingleStatementExec],
conditionalBodies: Seq[CompoundBodyExec],
elseBody: Option[CompoundBodyExec],
session: SparkSession) extends NonLeafStatementExec {
private object IfElseState extends Enumeration {
val Condition, Body = Value
}

private var state = IfElseState.Condition
private var curr: Option[CompoundStatementExec] = Some(conditions.head)

private var clauseIdx: Int = 0
private val conditionsCount = conditions.length
assert(conditionsCount == conditionalBodies.length)

private lazy val treeIterator: Iterator[CompoundStatementExec] =
new Iterator[CompoundStatementExec] {
override def hasNext: Boolean = curr.nonEmpty

override def next(): CompoundStatementExec = state match {
case IfElseState.Condition =>
assert(curr.get.isInstanceOf[SingleStatementExec])
val condition = curr.get.asInstanceOf[SingleStatementExec]
if (evaluateBooleanCondition(session, condition)) {
state = IfElseState.Body
curr = Some(conditionalBodies(clauseIdx))
} else {
clauseIdx += 1
if (clauseIdx < conditionsCount) {
// There are ELSE IF clauses remaining.
state = IfElseState.Condition
curr = Some(conditions(clauseIdx))
} else if (elseBody.isDefined) {
// ELSE clause exists.
state = IfElseState.Body
curr = Some(elseBody.get)
} else {
// No remaining clauses.
curr = None
}
}
condition
case IfElseState.Body =>
assert(curr.get.isInstanceOf[CompoundBodyExec])
val currBody = curr.get.asInstanceOf[CompoundBodyExec]
val retStmt = currBody.getTreeIterator.next()
if (!currBody.getTreeIterator.hasNext) {
curr = None
}
retStmt
}
}

override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator

override def reset(): Unit = {
state = IfElseState.Condition
curr = Some(conditions.head)
clauseIdx = 0
conditions.foreach(c => c.reset())
conditionalBodies.foreach(b => b.reset())
elseBody.foreach(b => b.reset())
}
}
Loading