From 8420c02f6c43f8d04873912ec8586eea934daefd Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 5 Dec 2017 17:40:49 +0000 Subject: [PATCH 1/8] initial commit --- .../expressions/conditionalExpressions.scala | 73 ++++++++++--------- .../expressions/nullExpressions.scala | 47 +++++++----- .../sql/catalyst/expressions/predicates.scala | 41 ++++++----- .../ConditionalExpressionSuite.scala | 7 ++ .../expressions/NullExpressionsSuite.scala | 7 ++ .../catalyst/expressions/PredicateSuite.scala | 8 +- 6 files changed, 109 insertions(+), 74 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ae5f7140847db..a9038b879e89a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -180,13 +180,13 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // This variable represents whether the first successful condition is met or not. - // It is initialized to `false` and it is set to `true` when the first condition which - // evaluates to `true` is met and therefore is not needed to go on anymore on the computation + // This variable represents whether the condition is met with true/false or not met. + // It is initialized to -1 and it is set to 0 or 1 when the condition which evaluates to + // `false` or `true` is met and therefore is not needed to go on anymore on the computation // of the following conditions. val conditionMet = ctx.freshName("caseWhenConditionMet") - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + val value = ctx.freshName("value") + ctx.addMutableState(ctx.javaType(dataType), value) // these blocks are meant to be inside a // do { @@ -200,9 +200,8 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | ${ev.isNull} = ${res.isNull}; - | ${ev.value} = ${res.value}; - | $conditionMet = true; + | $conditionMet = (byte)(${res.isNull} ? 1 : 0); + | $value = ${res.value}; | continue; |} """.stripMargin @@ -212,8 +211,8 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |${ev.isNull} = ${res.isNull}; - |${ev.value} = ${res.value}; + |$conditionMet = (byte)(${res.isNull} ? 1 : 0); + |$value = ${res.value}; """.stripMargin } @@ -221,17 +220,17 @@ case class CaseWhen( // This generates code like: // conditionMet = caseWhen_1(i); - // if(conditionMet) { + // if(conditionMet != -1) { // continue; // } // conditionMet = caseWhen_2(i); - // if(conditionMet) { + // if(conditionMet != -1) { // continue; // } // ... 
// and the declared methods are: - // private boolean caseWhen_1234() { - // boolean conditionMet = false; + // private byte caseWhen_1234() { + // byte conditionMet = -1; // do { // // here the evaluation of the conditions // } while (false); @@ -240,31 +239,35 @@ case class CaseWhen( val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, funcName = "caseWhen", - returnType = ctx.JAVA_BOOLEAN, - makeSplitFunction = func => - s""" - |${ctx.JAVA_BOOLEAN} $conditionMet = false; - |do { - | $func - |} while (false); - |return $conditionMet; - """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$conditionMet = $funcCall; - |if ($conditionMet) { - | continue; - |} - """.stripMargin - }.mkString) + returnType = ctx.JAVA_BYTE, + makeSplitFunction = { + func => + s""" + ${ctx.JAVA_BYTE} $conditionMet = -1; + do { + $func + } while (false); + return $conditionMet; + """ + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + $conditionMet = $funcCall; + if ($conditionMet != -1) { + continue; + }""" + }.mkString + }) ev.copy(code = s""" - ${ev.isNull} = true; - ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.JAVA_BOOLEAN} $conditionMet = false; + ${ctx.JAVA_BYTE} $conditionMet = -1; + $value = ${ctx.defaultValue(dataType)}; do { $codes - } while (false);""") + } while (false); + boolean ${ev.isNull} = ($conditionMet != 0); // TRUE if -1 or 1 + ${ctx.javaType(dataType)} ${ev.value} = $value;""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 26c9a41efc9f9..712d06ccb30cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) - ctx.addMutableState(ctx.javaType(dataType), ev.value) + val isNull = ctx.freshName("isNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) // all the evals are meant to be in a do { ... 
} while (false); loop val evals = children.map { e => @@ -81,7 +81,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | ${ev.isNull} = false; + | $isNull = false; | ${ev.value} = ${eval.value}; | continue; |} @@ -91,29 +91,36 @@ case class Coalesce(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", - makeSplitFunction = func => - s""" - |do { - | $func - |} while (false); - """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$funcCall; - |if (!${ev.isNull}) { - | continue; - |} - """.stripMargin - }.mkString) - + extraArguments = (ctx.javaType(dataType), ev.value) :: Nil, + returnType = ctx.javaType(dataType), + makeSplitFunction = { + func => + s""" + |do { + | $func + |} while (false); + |return ${ev.value}; + """.stripMargin + }, + foldFunctions = { funcCalls => + funcCalls.map { funcCall => + s""" + |${ev.value} = $funcCall; + |if (!$isNull) { + | continue; + |} + """.stripMargin + }.mkString + }) ev.copy(code = s""" - |${ev.isNull} = true; - |${ev.value} = ${ctx.defaultValue(dataType)}; + |$isNull = true; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); + |boolean ${ev.isNull} = $isNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 04e669492ec6d..c51cb7371df24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -237,8 +237,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.value) - ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull) + // inValue -1:isNull, 0:false, 1:true + val inValue = ctx.freshName("value") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. 
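A note on the do { ... } while (false); idiom that the comments above refer to: the loop never repeats, it only gives the generated blocks a scope they can leave early with continue once a verdict is known. A minimal hand-written Java sketch of the pattern (illustrative only, not actual Spark-generated output):

    // Each candidate block either decides the result and skips the rest via `continue`,
    // or falls through to the next block; the loop body runs exactly once either way.
    static byte firstPositive(int a, int b) {
      byte state = -1;                       // -1: nothing decided yet
      do {
        if (a > 0) { state = 1; continue; }  // verdict reached, skip the remaining blocks
        if (b > 0) { state = 1; continue; }
        state = 0;                           // fall-through: no block matched
      } while (false);
      return state;
    }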
@@ -246,10 +246,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | ${ev.isNull} = true; + | $inValue = -1; // isNull = true |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | ${ev.isNull} = false; - | ${ev.value} = true; + | $inValue = 1; // value = TRUE | continue; |} """.stripMargin) @@ -257,33 +256,39 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: Nil, - makeSplitFunction = body => + extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, inValue) :: Nil, + returnType = ctx.JAVA_BYTE, + makeSplitFunction = { body => s""" |do { | $body |} while (false); - """.stripMargin, - foldFunctions = _.map { funcCall => - s""" - |$funcCall; - |if (${ev.value}) { - | continue; - |} + |return $inValue; """.stripMargin - }.mkString("\n")) + }, + foldFunctions = { funcCalls => + funcCalls.map(funcCall => + s""" + |$inValue = $funcCall; + |if ($inValue == 1) { + | continue; + |} + """.stripMargin).mkString("\n") + } + ) ev.copy(code = s""" |${valueGen.code} - |${ev.value} = false; - |${ev.isNull} = ${valueGen.isNull}; - |if (!${ev.isNull}) { + |byte $inValue = (byte)(${valueGen.isNull} ? -1 : 0); // isNull or FALSE + |if ($inValue != -1) { | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes | } while (false); |} + |boolean ${ev.isNull} = ($inValue == -1); + |boolean ${ev.value} = ($inValue == 1); """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 3e11c3d2d4fe3..60d84aae1fa3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -145,4 +146,10 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper IndexedSeq((Literal(12) === Literal(1), Literal(42)), (Literal(12) === Literal(42), Literal(1)))) } + + test("SPARK-22705: case when should use less global variables") { + val ctx = new CodegenContext() + CaseWhen(Seq((Literal.create(false, BooleanType), Literal(1))), Literal(-1)).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 40ef7770da33f..a23cd95632770 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.types._ @@ -155,6 +156,12 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(inputs), "x_1") } + test("SPARK-22705: Coalesce should use less global variables") { + val ctx = new CodegenContext() + Coalesce(Seq(Literal("a"), Literal("b"))).genCode(ctx) + assert(ctx.mutableStates.size == 1) + } + test("AtLeastNNonNulls should not throw 64kb exception") { val inputs = (1 to 4000).map(x => Literal(s"x_$x")) checkEvaluation(AtLeastNNonNulls(1, inputs), true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0079e4e8d6f74..e8b4ec1618964 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -245,6 +245,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(In(Literal(1.0D), sets), true) } + test("SPARK-22705: In should use less global variables") { + val ctx = new CodegenContext() + In(Literal(1.0D), Seq(Literal(1.0D), Literal(2.0D))).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } + test("INSET") { val hS = HashSet[Any]() + 1 + 2 val nS = HashSet[Any]() + 1 + 2 + null From bcbe82c870073163a68cb68994c2a1ed2f60705b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Tue, 5 Dec 2017 19:44:42 +0000 Subject: [PATCH 2/8] fix scala style error --- .../apache/spark/sql/catalyst/expressions/PredicateSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index e8b4ec1618964..c85d24dd245d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow From c81d795565ae347243c3571885820b5ac15bf494 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 6 Dec 2017 07:51:23 +0000 Subject: [PATCH 3/8] address review comment --- .../expressions/conditionalExpressions.scala | 75 ++++++++++--------- .../expressions/nullExpressions.scala | 14 ++-- .../sql/catalyst/expressions/predicates.scala | 26 ++++--- 3 files changed, 61 insertions(+), 54 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index a9038b879e89a..fb88ede79d7fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -180,13 +180,15 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // This variable represents whether the condition is met with true/false or not met. - // It is initialized to -1 and it is set to 0 or 1 when the condition which evaluates to - // `false` or `true` is met and therefore is not needed to go on anymore on the computation - // of the following conditions. - val conditionMet = ctx.freshName("caseWhenConditionMet") - val value = ctx.freshName("value") - ctx.addMutableState(ctx.javaType(dataType), value) + // This variable represents whether the evaluated result is null or not. It's a byte value + // instead of boolean because it carries an extra information about if the case-when condition + // is met or not. It is initialized to `-1`, which means the condition is not met yet and the + // result is unknown. When the first condition is met, it is set to `1` if result is null, or + // `0` if result is not null. We won't go on anymore on the computation if it's set to `1` or + // `0`. + val resultIsNull = ctx.freshName("caseWhenResultIsNull") + val tmpResult = ctx.freshName("caseWhenTmpResult") + ctx.addMutableState(ctx.javaType(dataType), tmpResult) // these blocks are meant to be inside a // do { @@ -200,8 +202,8 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | $conditionMet = (byte)(${res.isNull} ? 1 : 0); - | $value = ${res.value}; + | $resultIsNull = (byte)(${res.isNull} ? 1 : 0); + | $tmpResult = ${res.value}; | continue; |} """.stripMargin @@ -211,30 +213,30 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |$conditionMet = (byte)(${res.isNull} ? 1 : 0); - |$value = ${res.value}; + |$resultIsNull = (byte)(${res.isNull} ? 1 : 0); + |$tmpResult = ${res.value}; """.stripMargin } val allConditions = cases ++ elseCode // This generates code like: - // conditionMet = caseWhen_1(i); - // if(conditionMet != -1) { + // caseWhenResultIsNull = caseWhen_1(i); + // if(caseWhenResultIsNull != -1) { // continue; // } - // conditionMet = caseWhen_2(i); - // if(conditionMet != -1) { + // caseWhenResultIsNull = caseWhen_2(i); + // if(caseWhenResultIsNull != -1) { // continue; // } // ... 
// and the declared methods are: // private byte caseWhen_1234() { - // byte conditionMet = -1; + // byte caseWhenResultIsNull = -1; // do { // // here the evaluation of the conditions // } while (false); - // return conditionMet; + // return caseWhenResultIsNull; // } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, @@ -243,31 +245,34 @@ case class CaseWhen( makeSplitFunction = { func => s""" - ${ctx.JAVA_BYTE} $conditionMet = -1; - do { - $func - } while (false); - return $conditionMet; - """ + |${ctx.JAVA_BYTE} $resultIsNull = -1; + |do { + | $func + |} while (false); + |return $resultIsNull; + """.stripMargin }, foldFunctions = { funcCalls => funcCalls.map { funcCall => s""" - $conditionMet = $funcCall; - if ($conditionMet != -1) { - continue; - }""" + |$resultIsNull = $funcCall; + |if ($resultIsNull != -1) { + | continue; + |} + """.stripMargin }.mkString }) - ev.copy(code = s""" - ${ctx.JAVA_BYTE} $conditionMet = -1; - $value = ${ctx.defaultValue(dataType)}; - do { - $codes - } while (false); - boolean ${ev.isNull} = ($conditionMet != 0); // TRUE if -1 or 1 - ${ctx.javaType(dataType)} ${ev.value} = $value;""") + ev.copy(code = + s""" + |${ctx.JAVA_BYTE} $resultIsNull = -1; + |$tmpResult = ${ctx.defaultValue(dataType)}; + |do { + | $codes + |} while (false); + |boolean ${ev.isNull} = ($resultIsNull != 0); // TRUE if -1 or 1 + |${ctx.javaType(dataType)} ${ev.value} = $tmpResult; + """) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 712d06ccb30cb..735ed48bdc5b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val isNull = ctx.freshName("isNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, isNull) + val coalesceTmpIsNull = ctx.freshName("coalesceTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, coalesceTmpIsNull) // all the evals are meant to be in a do { ... 
} while (false); loop val evals = children.map { e => @@ -81,7 +81,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | $isNull = false; + | $coalesceTmpIsNull = false; | ${ev.value} = ${eval.value}; | continue; |} @@ -91,11 +91,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", - extraArguments = (ctx.javaType(dataType), ev.value) :: Nil, returnType = ctx.javaType(dataType), makeSplitFunction = { func => s""" + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $func |} while (false); @@ -106,7 +106,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { funcCalls.map { funcCall => s""" |${ev.value} = $funcCall; - |if (!$isNull) { + |if (!$coalesceTmpIsNull) { | continue; |} """.stripMargin @@ -115,12 +115,12 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" - |$isNull = true; + |$coalesceTmpIsNull = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); - |boolean ${ev.isNull} = $isNull; + |boolean ${ev.isNull} = $coalesceTmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c51cb7371df24..66b9c487960be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -237,8 +237,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - // inValue -1:isNull, 0:false, 1:true - val inValue = ctx.freshName("value") + // inTmpResult -1 indicates at lease one expr in list is evaluated to null. + // 0 means no matches found. 1 means the expr in list matches the given value expr. + val inTmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. 
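For the Coalesce change above, a hand-written analogue of what the new template produces when the children are split into two helper methods might look as follows; the class, method, and field names here are invented for illustration, and only the temporary null flag remains an instance field, mirroring coalesceTmpIsNull:

    final class CoalesceSketch {
      private boolean tmpIsNull;                        // the one remaining mutable field

      Integer coalesce(Integer a, Integer b, Integer c, Integer d) {
        tmpIsNull = true;
        Integer value = null;                           // the result value is now a local
        do {
          value = coalesce_0(a, b);                     // first batch of children
          if (!tmpIsNull) { continue; }                 // found a non-null child, stop
          value = coalesce_1(c, d);                     // second batch, only if still null
          if (!tmpIsNull) { continue; }
        } while (false);
        return tmpIsNull ? null : value;                // isNull becomes a local boolean again
      }

      private Integer coalesce_0(Integer a, Integer b) {
        Integer value = null;
        do {
          if (a != null) { tmpIsNull = false; value = a; continue; }
          if (b != null) { tmpIsNull = false; value = b; continue; }
        } while (false);
        return value;                                   // each split function returns the value
      }

      private Integer coalesce_1(Integer c, Integer d) {
        Integer value = null;
        do {
          if (c != null) { tmpIsNull = false; value = c; continue; }
          if (d != null) { tmpIsNull = false; value = d; continue; }
        } while (false);
        return value;
      }
    }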
@@ -246,9 +247,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | $inValue = -1; // isNull = true + | $inTmpResult = -1; // isNull = true |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | $inValue = 1; // value = TRUE + | $inTmpResult = 1; // value = TRUE | continue; |} """.stripMargin) @@ -256,21 +257,21 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, inValue) :: Nil, + extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, inTmpResult) :: Nil, returnType = ctx.JAVA_BYTE, makeSplitFunction = { body => s""" |do { | $body |} while (false); - |return $inValue; + |return $inTmpResult; """.stripMargin }, foldFunctions = { funcCalls => funcCalls.map(funcCall => s""" - |$inValue = $funcCall; - |if ($inValue == 1) { + |$inTmpResult = $funcCall; + |if ($inTmpResult == 1) { | continue; |} """.stripMargin).mkString("\n") @@ -280,15 +281,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ev.copy(code = s""" |${valueGen.code} - |byte $inValue = (byte)(${valueGen.isNull} ? -1 : 0); // isNull or FALSE - |if ($inValue != -1) { + |// TRUE if any condition is met and the result is not null, or no any condition is met. + |byte $inTmpResult = (byte)(${valueGen.isNull} ? -1 : 0); + |if ($inTmpResult != -1) { | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes | } while (false); |} - |boolean ${ev.isNull} = ($inValue == -1); - |boolean ${ev.value} = ($inValue == 1); + |boolean ${ev.isNull} = ($inTmpResult == -1); + |boolean ${ev.value} = ($inTmpResult == 1); """.stripMargin) } From 96183d390a4b7887736e66a1ab6616f052ee9f36 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 6 Dec 2017 09:27:55 +0000 Subject: [PATCH 4/8] address review comments --- .../expressions/conditionalExpressions.scala | 41 ++++++++-------- .../expressions/nullExpressions.scala | 47 +++++++++---------- .../sql/catalyst/expressions/predicates.scala | 40 +++++++--------- 3 files changed, 59 insertions(+), 69 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index fb88ede79d7fe..6c4a8b1e54467 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -242,26 +242,22 @@ case class CaseWhen( expressions = allConditions, funcName = "caseWhen", returnType = ctx.JAVA_BYTE, - makeSplitFunction = { - func => - s""" - |${ctx.JAVA_BYTE} $resultIsNull = -1; - |do { - | $func - |} while (false); - |return $resultIsNull; - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map { funcCall => - s""" - |$resultIsNull = $funcCall; - |if ($resultIsNull != -1) { - | continue; - |} - """.stripMargin - }.mkString - }) + makeSplitFunction = func => + s""" + |${ctx.JAVA_BYTE} $resultIsNull = -1; + |do { + | $func + |} while (false); + |return $resultIsNull; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$resultIsNull = $funcCall; + |if ($resultIsNull != -1) { + | continue; + |} + """.stripMargin + }.mkString) ev.copy(code = s""" @@ -270,8 +266,9 @@ case class CaseWhen( 
|do { | $codes |} while (false); - |boolean ${ev.isNull} = ($resultIsNull != 0); // TRUE if -1 or 1 - |${ctx.javaType(dataType)} ${ev.value} = $tmpResult; + |// TRUE if any condition is met and the result is not null, or no any condition is met. + |final boolean ${ev.isNull} = ($resultIsNull != 0); + |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; """) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 735ed48bdc5b6..d00ce26a18e57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -72,8 +72,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val coalesceTmpIsNull = ctx.freshName("coalesceTmpIsNull") - ctx.addMutableState(ctx.JAVA_BOOLEAN, coalesceTmpIsNull) + val tmpIsNull = ctx.freshName("coalesceTmpIsNull") + ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull) // all the evals are meant to be in a do { ... } while (false); loop val evals = children.map { e => @@ -81,7 +81,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { s""" |${eval.code} |if (!${eval.isNull}) { - | $coalesceTmpIsNull = false; + | $tmpIsNull = false; | ${ev.value} = ${eval.value}; | continue; |} @@ -92,35 +92,32 @@ case class Coalesce(children: Seq[Expression]) extends Expression { expressions = evals, funcName = "coalesce", returnType = ctx.javaType(dataType), - makeSplitFunction = { - func => - s""" - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - |do { - | $func - |} while (false); - |return ${ev.value}; - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map { funcCall => - s""" - |${ev.value} = $funcCall; - |if (!$coalesceTmpIsNull) { - | continue; - |} - """.stripMargin - }.mkString - }) + makeSplitFunction = func => + s""" + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |do { + | $func + |} while (false); + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |${ev.value} = $funcCall; + |if (!$tmpIsNull) { + | continue; + |} + """.stripMargin + }.mkString) + ev.copy(code = s""" - |$coalesceTmpIsNull = true; + |$tmpIsNull = true; |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); - |boolean ${ev.isNull} = $coalesceTmpIsNull; + |final boolean ${ev.isNull} = $tmpIsNull; """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 66b9c487960be..c43d2e54db320 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -239,7 +239,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val listGen = list.map(_.genCode(ctx)) // inTmpResult -1 indicates at lease one expr in list is evaluated to null. // 0 means no matches found. 1 means the expr in list matches the given value expr. 
- val inTmpResult = ctx.freshName("inTmpResult") + val tmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. @@ -247,9 +247,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | $inTmpResult = -1; // isNull = true + | $tmpResult = -1; // isNull = true |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | $inTmpResult = 1; // value = TRUE + | $tmpResult = 1; // value = TRUE | continue; |} """.stripMargin) @@ -257,40 +257,36 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val codes = ctx.splitExpressionsWithCurrentInputs( expressions = listCode, funcName = "valueIn", - extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, inTmpResult) :: Nil, + extraArguments = (javaDataType, valueArg) :: (ctx.JAVA_BYTE, tmpResult) :: Nil, returnType = ctx.JAVA_BYTE, - makeSplitFunction = { body => + makeSplitFunction = body => s""" |do { | $body |} while (false); - |return $inTmpResult; + |return $tmpResult; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$tmpResult = $funcCall; + |if ($tmpResult == 1) { + | continue; + |} """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => - s""" - |$inTmpResult = $funcCall; - |if ($inTmpResult == 1) { - | continue; - |} - """.stripMargin).mkString("\n") - } - ) + }.mkString("\n")) ev.copy(code = s""" |${valueGen.code} - |// TRUE if any condition is met and the result is not null, or no any condition is met. - |byte $inTmpResult = (byte)(${valueGen.isNull} ? -1 : 0); - |if ($inTmpResult != -1) { + |byte $tmpResult = (byte)(${valueGen.isNull} ? -1 : 0); + |if ($tmpResult != -1) { | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes | } while (false); |} - |boolean ${ev.isNull} = ($inTmpResult == -1); - |boolean ${ev.value} = ($inTmpResult == 1); + |final boolean ${ev.isNull} = ($tmpResult == -1); + |final boolean ${ev.value} = ($tmpResult == 1); """.stripMargin) } From 6a14e160892393168bd0bce0f041c52c7044e869 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 6 Dec 2017 13:26:13 +0000 Subject: [PATCH 5/8] address review comments --- .../expressions/conditionalExpressions.scala | 20 +++++++++---------- .../expressions/nullExpressions.scala | 7 ++++--- .../sql/catalyst/expressions/predicates.scala | 1 + 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 6c4a8b1e54467..1fd28f8a2c3c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -186,7 +186,7 @@ case class CaseWhen( // result is unknown. When the first condition is met, it is set to `1` if result is null, or // `0` if result is not null. We won't go on anymore on the computation if it's set to `1` or // `0`. 
- val resultIsNull = ctx.freshName("caseWhenResultIsNull") + val resultState = ctx.freshName("caseWhenResultState") val tmpResult = ctx.freshName("caseWhenTmpResult") ctx.addMutableState(ctx.javaType(dataType), tmpResult) @@ -202,7 +202,7 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | $resultIsNull = (byte)(${res.isNull} ? 1 : 0); + | $resultState = (byte)(${res.isNull} ? 1 : 0); | $tmpResult = ${res.value}; | continue; |} @@ -213,7 +213,7 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |$resultIsNull = (byte)(${res.isNull} ? 1 : 0); + |$resultState = (byte)(${res.isNull} ? 1 : 0); |$tmpResult = ${res.value}; """.stripMargin } @@ -244,16 +244,16 @@ case class CaseWhen( returnType = ctx.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultIsNull = -1; + |${ctx.JAVA_BYTE} $resultState = -1; |do { | $func |} while (false); - |return $resultIsNull; + |return $resultState; """.stripMargin, foldFunctions = _.map { funcCall => s""" - |$resultIsNull = $funcCall; - |if ($resultIsNull != -1) { + |$resultState = $funcCall; + |if ($resultState != -1) { | continue; |} """.stripMargin @@ -261,15 +261,15 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultIsNull = -1; + |${ctx.JAVA_BYTE} $resultState = -1; |$tmpResult = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); |// TRUE if any condition is met and the result is not null, or no any condition is met. - |final boolean ${ev.isNull} = ($resultIsNull != 0); + |final boolean ${ev.isNull} = ($resultState != 0); |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; - """) + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index d00ce26a18e57..294cdcb2e9546 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -88,13 +88,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """.stripMargin } + val resultType = ctx.javaType(dataType) val codes = ctx.splitExpressionsWithCurrentInputs( expressions = evals, funcName = "coalesce", - returnType = ctx.javaType(dataType), + returnType = resultType, makeSplitFunction = func => s""" - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $func |} while (false); @@ -113,7 +114,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" |$tmpIsNull = true; - |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |$resultType ${ev.value} = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index c43d2e54db320..cab12e9dc2066 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -245,6 +245,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { // The evaluation of variables can be stopped when we find a matching value. 
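The extraArguments/returnType combination used for In above is what lets the split helpers work without any global field: the byte state is passed in by value and handed back through the return value. A rough hand-written sketch of one such helper and its call site, with invented names rather than real Spark output:

    // One split helper: receives the probe value and the current state, returns the new state.
    static byte valueIn_0(int valueArg, byte tmpResult, Integer candidate0, Integer candidate1) {
      do {
        if (candidate0 == null) {
          tmpResult = -1;                              // remember that a null was seen
        } else if (candidate0 == valueArg) {
          tmpResult = 1; continue;                     // matched, skip the rest
        }
        if (candidate1 == null) {
          tmpResult = -1;
        } else if (candidate1 == valueArg) {
          tmpResult = 1; continue;
        }
      } while (false);
      return tmpResult;
    }

    // Call-site shape: the caller re-assigns the state from the return value, e.g.
    //   tmpResult = valueIn_0(valueArg, tmpResult, c0, c1);
    //   if (tmpResult == 1) { continue; }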
val listCode = listGen.map(x => s""" + |$tmpResult = 0; |${x.code} |if (${x.isNull}) { | $tmpResult = -1; // isNull = true From 740f1a0bd3ed0995dc6ec9946f06b884a18cd77b Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 6 Dec 2017 14:38:58 +0000 Subject: [PATCH 6/8] address review comments --- .../expressions/conditionalExpressions.scala | 24 +++++++++---------- .../sql/catalyst/expressions/predicates.scala | 12 ++++++---- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1fd28f8a2c3c8..1624d8cf24024 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -180,12 +180,12 @@ case class CaseWhen( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - // This variable represents whether the evaluated result is null or not. It's a byte value - // instead of boolean because it carries an extra information about if the case-when condition - // is met or not. It is initialized to `-1`, which means the condition is not met yet and the - // result is unknown. When the first condition is met, it is set to `1` if result is null, or - // `0` if result is not null. We won't go on anymore on the computation if it's set to `1` or - // `0`. + // This variable holds the state of the result: + // -1 means the condition is not met yet and the result is unknown. + // 0 means the condition is met and result is not null. + // 1 means the condition is met and result is null. + // It is initialized to `-1`, and if it's set to `1` or `0`, We won't go on anymore on the + // computation. val resultState = ctx.freshName("caseWhenResultState") val tmpResult = ctx.freshName("caseWhenTmpResult") ctx.addMutableState(ctx.javaType(dataType), tmpResult) @@ -221,22 +221,22 @@ case class CaseWhen( val allConditions = cases ++ elseCode // This generates code like: - // caseWhenResultIsNull = caseWhen_1(i); - // if(caseWhenResultIsNull != -1) { + // caseWhenResultState = caseWhen_1(i); + // if(caseWhenResultState != -1) { // continue; // } - // caseWhenResultIsNull = caseWhen_2(i); - // if(caseWhenResultIsNull != -1) { + // caseWhenResultState = caseWhen_2(i); + // if(caseWhenResultState != -1) { // continue; // } // ... 
// and the declared methods are: // private byte caseWhen_1234() { - // byte caseWhenResultIsNull = -1; + // byte caseWhenResultState = -1; // do { // // here the evaluation of the conditions // } while (false); - // return caseWhenResultIsNull; + // return caseWhenResultState; // } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = allConditions, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index cab12e9dc2066..9398903183740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -237,15 +237,16 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val javaDataType = ctx.javaType(value.dataType) val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) - // inTmpResult -1 indicates at lease one expr in list is evaluated to null. - // 0 means no matches found. 1 means the expr in list matches the given value expr. + // inTmpResult has 3 possible values: + // -1 means no matches found and there is at least one value in the list evaluated to null + // 0 means no matches found and all values in the list are not null + // 1 means one value in the list is matched val tmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. // The evaluation of variables can be stopped when we find a matching value. val listCode = listGen.map(x => s""" - |$tmpResult = 0; |${x.code} |if (${x.isNull}) { | $tmpResult = -1; // isNull = true @@ -279,8 +280,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ev.copy(code = s""" |${valueGen.code} - |byte $tmpResult = (byte)(${valueGen.isNull} ? -1 : 0); - |if ($tmpResult != -1) { + |byte $tmpResult = -1; + |if (!${valueGen.isNull}) { + | $tmpResult = 0; | $javaDataType $valueArg = ${valueGen.value}; | do { | $codes From 31ab853c6d496ac5d6a0a4c4e6b9cf03e7cbf466 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Wed, 6 Dec 2017 23:56:35 +0000 Subject: [PATCH 7/8] address review comments --- .../expressions/conditionalExpressions.scala | 21 +++++++++++-------- .../sql/catalyst/expressions/predicates.scala | 21 +++++++++++-------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1624d8cf24024..1283b0cc02272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -181,9 +181,12 @@ case class CaseWhen( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // This variable holds the state of the result: - // -1 means the condition is not met yet and the result is unknown. - // 0 means the condition is met and result is not null. - // 1 means the condition is met and result is null. + // -1 means the condition is not met yet and the result is unknown. + val NOT_MATCHED = -1 + // 0 means the condition is met and result is not null. + val HAS_NONNULL = 0 + // 1 means the condition is met and result is null. 
+ val HAS_NULL = 1 // It is initialized to `-1`, and if it's set to `1` or `0`, We won't go on anymore on the // computation. val resultState = ctx.freshName("caseWhenResultState") @@ -202,7 +205,7 @@ case class CaseWhen( |${cond.code} |if (!${cond.isNull} && ${cond.value}) { | ${res.code} - | $resultState = (byte)(${res.isNull} ? 1 : 0); + | $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); | $tmpResult = ${res.value}; | continue; |} @@ -213,7 +216,7 @@ case class CaseWhen( val res = elseExpr.genCode(ctx) s""" |${res.code} - |$resultState = (byte)(${res.isNull} ? 1 : 0); + |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL); |$tmpResult = ${res.value}; """.stripMargin } @@ -244,7 +247,7 @@ case class CaseWhen( returnType = ctx.JAVA_BYTE, makeSplitFunction = func => s""" - |${ctx.JAVA_BYTE} $resultState = -1; + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $func |} while (false); @@ -253,7 +256,7 @@ case class CaseWhen( foldFunctions = _.map { funcCall => s""" |$resultState = $funcCall; - |if ($resultState != -1) { + |if ($resultState != $NOT_MATCHED) { | continue; |} """.stripMargin @@ -261,13 +264,13 @@ case class CaseWhen( ev.copy(code = s""" - |${ctx.JAVA_BYTE} $resultState = -1; + |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED; |$tmpResult = ${ctx.defaultValue(dataType)}; |do { | $codes |} while (false); |// TRUE if any condition is met and the result is not null, or no any condition is met. - |final boolean ${ev.isNull} = ($resultState != 0); + |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9398903183740..dfafd4c86fa01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -238,9 +238,12 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: - // -1 means no matches found and there is at least one value in the list evaluated to null - // 0 means no matches found and all values in the list are not null - // 1 means one value in the list is matched + // -1 means no matches found and there is at least one value in the list evaluated + val HAS_NULL = -1 + // 0 means no matches found and all values in the list are not null + val NOT_MATCHED = 0 + // 1 means one value in the list is matched + val MATCHED = 1 val tmpResult = ctx.freshName("inTmpResult") val valueArg = ctx.freshName("valueArg") // All the blocks are meant to be inside a do { ... } while (false); loop. 
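With the HAS_NULL / NOT_MATCHED / MATCHED encoding introduced here, the whole In evaluation collapses to a single byte of state. A compilable hand-written analogue of the semantics (not the actual generated Java, which uses fresh variable names and split methods):

    // -1 = HAS_NULL (no match, but a null was involved), 0 = NOT_MATCHED, 1 = MATCHED.
    static Boolean in(Integer value, Integer... list) {
      byte state = -1;                                 // a null probe value leaves the result null
      if (value != null) {
        state = 0;                                     // NOT_MATCHED until proven otherwise
        for (Integer x : list) {
          if (x == null) {
            state = -1;                                // a null in the list keeps the result unknown
          } else if (x.equals(value)) {
            state = 1;                                 // MATCHED, no need to look further
            break;
          }
        }
      }
      // SQL three-valued logic: unknown (null) unless we either matched or saw no nulls at all.
      return state == -1 ? null : Boolean.valueOf(state == 1);
    }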
@@ -249,9 +252,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { s""" |${x.code} |if (${x.isNull}) { - | $tmpResult = -1; // isNull = true + | $tmpResult = $HAS_NULL; // ${ev.isNull} = true; |} else if (${ctx.genEqual(value.dataType, valueArg, x.value)}) { - | $tmpResult = 1; // value = TRUE + | $tmpResult = $MATCHED; // ${ev.isNull} = false; ${ev.value} = true; | continue; |} """.stripMargin) @@ -271,7 +274,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { foldFunctions = _.map { funcCall => s""" |$tmpResult = $funcCall; - |if ($tmpResult == 1) { + |if ($tmpResult == $MATCHED) { | continue; |} """.stripMargin @@ -280,7 +283,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { ev.copy(code = s""" |${valueGen.code} - |byte $tmpResult = -1; + |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { | $tmpResult = 0; | $javaDataType $valueArg = ${valueGen.value}; @@ -288,8 +291,8 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { | $codes | } while (false); |} - |final boolean ${ev.isNull} = ($tmpResult == -1); - |final boolean ${ev.value} = ($tmpResult == 1); + |final boolean ${ev.isNull} = ($tmpResult == $HAS_NULL); + |final boolean ${ev.value} = ($tmpResult == $MATCHED); """.stripMargin) } From c4691b6b4cd8b2ac26dd9c89243372dfec5bb913 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 7 Dec 2017 01:42:00 +0000 Subject: [PATCH 8/8] address review comments --- .../sql/catalyst/expressions/conditionalExpressions.scala | 6 +++--- .../apache/spark/sql/catalyst/expressions/predicates.scala | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 1283b0cc02272..53c3b226895ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -187,8 +187,8 @@ case class CaseWhen( val HAS_NONNULL = 0 // 1 means the condition is met and result is null. val HAS_NULL = 1 - // It is initialized to `-1`, and if it's set to `1` or `0`, We won't go on anymore on the - // computation. + // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`, + // We won't go on anymore on the computation. val resultState = ctx.freshName("caseWhenResultState") val tmpResult = ctx.freshName("caseWhenTmpResult") ctx.addMutableState(ctx.javaType(dataType), tmpResult) @@ -269,7 +269,7 @@ case class CaseWhen( |do { | $codes |} while (false); - |// TRUE if any condition is met and the result is not null, or no any condition is met. + |// TRUE if any condition is met and the result is null, or no any condition is met. 
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL); |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult; """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index dfafd4c86fa01..7445b657f9882 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -238,7 +238,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) // inTmpResult has 3 possible values: - // -1 means no matches found and there is at least one value in the list evaluated + // -1 means no matches found and there is at least one value in the list evaluated to null val HAS_NULL = -1 // 0 means no matches found and all values in the list are not null val NOT_MATCHED = 0
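Taken together, the CaseWhen rewrite leaves one mutable field (the temporary result value) plus a local byte result state. A hand-written analogue of the final semantics for a two-branch CASE WHEN with an ELSE clause (illustrative only; the real generated code threads the state through split caseWhen_N methods):

    // resultState: -1 = NOT_MATCHED, 0 = HAS_NONNULL, 1 = HAS_NULL.
    static Integer caseWhen(boolean cond1, Integer res1, boolean cond2, Integer res2, Integer elseRes) {
      byte resultState = -1;                            // NOT_MATCHED: no branch has fired yet
      Integer tmpResult = null;
      do {
        if (cond1) { resultState = (byte) (res1 == null ? 1 : 0); tmpResult = res1; continue; }
        if (cond2) { resultState = (byte) (res2 == null ? 1 : 0); tmpResult = res2; continue; }
        resultState = (byte) (elseRes == null ? 1 : 0); // the ELSE branch always decides when present
        tmpResult = elseRes;
      } while (false);
      // Null when the chosen branch produced null, or (without an ELSE) when nothing matched.
      boolean isNull = (resultState != 0);
      return isNull ? null : tmpResult;
    }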