New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-14577][SQL] Add spark.sql.codegen.maxCaseBranches config option #12353
Changes from 3 commits
25ca987
b97d3ea
a9294bd
502bc61
9a2340c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi | |
} | ||
|
||
/** | ||
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". | ||
* When a = true, returns b; when c = true, returns d; else returns e. | ||
* Abstract parent class for common logic in CaseWhen and CaseWhenCodegen. | ||
* | ||
* @param branches seq of (branch condition, branch value) | ||
* @param elseValue optional value for the else branch | ||
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") | ||
// scalastyle:on line.size.limit | ||
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) | ||
extends Expression with CodegenFallback { | ||
abstract class CaseWhenBase( | ||
branches: Seq[(Expression, Expression)], | ||
elseValue: Option[Expression]) | ||
extends Expression with Serializable { | ||
|
||
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue | ||
|
||
|
@@ -142,16 +139,54 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E | |
} | ||
} | ||
|
||
def shouldCodegen: Boolean = { | ||
branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN | ||
override def toString: String = { | ||
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString | ||
val elseCase = elseValue.map(" ELSE " + _).getOrElse("") | ||
"CASE" + cases + elseCase + " END" | ||
} | ||
|
||
override def sql: String = { | ||
val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString | ||
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") | ||
"CASE" + cases + elseCase + " END" | ||
} | ||
} | ||
|
||
|
||
/** | ||
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". | ||
* When a = true, returns b; when c = true, returns d; else returns e. | ||
* | ||
* @param branches seq of (branch condition, branch value) | ||
* @param elseValue optional value for the else branch | ||
*/ | ||
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") | ||
// scalastyle:on line.size.limit | ||
case class CaseWhen( | ||
val branches: Seq[(Expression, Expression)], | ||
val elseValue: Option[Expression] = None) | ||
extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe just have a toCodegen function that creates CaseWhenCodegen? We can then remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would be right. |
||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
super[CodegenFallback].doGenCode(ctx, ev) | ||
} | ||
} | ||
|
||
/** | ||
* CaseWhen expression used when code generation condition is satisfied. | ||
* OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen. | ||
* | ||
* @param branches seq of (branch condition, branch value) | ||
* @param elseValue optional value for the else branch | ||
*/ | ||
case class CaseWhenCodegen( | ||
val branches: Seq[(Expression, Expression)], | ||
val elseValue: Option[Expression] = None) | ||
extends CaseWhenBase(branches, elseValue) with Serializable { | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
if (!shouldCodegen) { | ||
// Fallback to interpreted mode if there are too many branches, as it may reach the | ||
// 64K limit (limit on bytecode size for a single function). | ||
return super[CodegenFallback].doGenCode(ctx, ev) | ||
} | ||
// Generate code that looks like: | ||
// | ||
// condA = ... | ||
|
@@ -202,26 +237,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E | |
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; | ||
$generatedCode""") | ||
} | ||
|
||
override def toString: String = { | ||
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString | ||
val elseCase = elseValue.map(" ELSE " + _).getOrElse("") | ||
"CASE" + cases + elseCase + " END" | ||
} | ||
|
||
override def sql: String = { | ||
val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString | ||
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") | ||
"CASE" + cases + elseCase + " END" | ||
} | ||
} | ||
|
||
/** Factory methods for CaseWhen. */ | ||
object CaseWhen { | ||
|
||
// The maximum number of switches supported with codegen. | ||
val MAX_NUM_CASES_FOR_CODEGEN = 20 | ||
|
||
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { | ||
CaseWhen(branches, Option(elseValue)) | ||
} | ||
|
@@ -242,6 +261,12 @@ object CaseWhen { | |
} | ||
} | ||
|
||
/** Factory methods for CaseWhenCodegen. */ | ||
object CaseWhenCodegen { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can remove this given the above comment |
||
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhenCodegen = { | ||
CaseWhenCodegen(branches, Option(elseValue)) | ||
} | ||
} | ||
|
||
/** | ||
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.optimizer | ||
|
||
import org.apache.spark.sql.catalyst.dsl.plans._ | ||
import org.apache.spark.sql.catalyst.SimpleCatalystConf | ||
import org.apache.spark.sql.catalyst.expressions._ | ||
import org.apache.spark.sql.catalyst.expressions.Literal._ | ||
import org.apache.spark.sql.catalyst.plans.PlanTest | ||
import org.apache.spark.sql.catalyst.plans.logical._ | ||
import org.apache.spark.sql.catalyst.rules._ | ||
|
||
|
||
class OptimizeCodegenSuite extends PlanTest { | ||
|
||
object Optimize extends RuleExecutor[LogicalPlan] { | ||
val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil | ||
} | ||
|
||
protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { | ||
val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze | ||
val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) | ||
comparePlans(actual, correctAnswer) | ||
} | ||
|
||
test("Codegen only when the number of branches is small.") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you make sure you construct a few more test cases one with nested casewhen, and one with multiple case when in one operator, and one with multiple casewhen in different operators There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh. Sure. I'll add those testcases, too. |
||
assertEquivalent( | ||
CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), | ||
CaseWhenCodegen(Seq((TrueLiteral, Literal(1))), Literal(2))) | ||
|
||
assertEquivalent( | ||
CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)), | ||
CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2))) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maxCaseBranchesForCodegen?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for quick review. Sure. And also
maxCaseBranchesForCodegen
in SQLConf.scala.