Skip to content

Commit

Permalink
[SPARK-14577][SQL] Add spark.sql.codegen.maxCaseBranches config option
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

We currently disable codegen for `CaseWhen` if the number of branches is 20 or more (the check is `branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN`, and that constant is 20). It would be better if this value were an internal (non-public) config defined in SQLConf.

## How was this patch tested?

Pass the Jenkins tests (including a new test case `Support spark.sql.codegen.maxCaseBranches option`)

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #12353 from dongjoon-hyun/SPARK-14577.
  • Loading branch information
dongjoon-hyun authored and cloud-fan committed Apr 19, 2016
1 parent 74fe235 commit 3d46d79
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 35 deletions.
Expand Up @@ -29,6 +29,7 @@ trait CatalystConf {
def groupByOrdinal: Boolean

def optimizerMaxIterations: Int
def maxCaseBranchesForCodegen: Int

/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
Expand All @@ -45,6 +46,7 @@ case class SimpleCatalystConf(
caseSensitiveAnalysis: Boolean,
orderByOrdinal: Boolean = true,
groupByOrdinal: Boolean = true,
optimizerMaxIterations: Int = 100)
optimizerMaxIterations: Int = 100,
maxCaseBranchesForCodegen: Int = 20)
extends CatalystConf {
}
Expand Up @@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}

/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
* Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
// scalastyle:on line.size.limit
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
extends Expression with CodegenFallback {
abstract class CaseWhenBase(
branches: Seq[(Expression, Expression)],
elseValue: Option[Expression])
extends Expression with Serializable {

override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue

Expand Down Expand Up @@ -142,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
}
}

def shouldCodegen: Boolean = {
branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
override def toString: String = {
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
"CASE" + cases + elseCase + " END"
}

override def sql: String = {
val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
}


/**
 * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
 * When a = true, returns b; when c = true, returns d; else returns e.
 *
 * This variant always delegates `doGenCode` to [[CodegenFallback]]; call `toCodegen()` to
 * obtain the [[CaseWhenCodegen]] equivalent built from the same branches.
 *
 * @param branches seq of (branch condition, branch value)
 * @param elseValue optional value for the else branch
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
// scalastyle:on line.size.limit
case class CaseWhen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None)
  extends CaseWhenBase(branches, elseValue) with CodegenFallback {

  // Explicitly select the CodegenFallback implementation of doGenCode.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    super[CodegenFallback].doGenCode(ctx, ev)
  }

  /** Returns a [[CaseWhenCodegen]] carrying the same branches and else value. */
  def toCodegen(): CaseWhenCodegen = {
    CaseWhenCodegen(branches, elseValue)
  }
}

/**
* CaseWhen expression used when code generation condition is satisfied.
* The OptimizeCodegen optimizer rule replaces CaseWhen with CaseWhenCodegen.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
case class CaseWhenCodegen(
val branches: Seq[(Expression, Expression)],
val elseValue: Option[Expression] = None)
extends CaseWhenBase(branches, elseValue) with Serializable {

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (!shouldCodegen) {
// Fallback to interpreted mode if there are too many branches, as it may reach the
// 64K limit (limit on bytecode size for a single function).
return super[CodegenFallback].doGenCode(ctx, ev)
}
// Generate code that looks like:
//
// condA = ...
Expand Down Expand Up @@ -202,26 +241,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$generatedCode""")
}

override def toString: String = {
val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
"CASE" + cases + elseCase + " END"
}

override def sql: String = {
val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
"CASE" + cases + elseCase + " END"
}
}

/** Factory methods for CaseWhen. */
object CaseWhen {

// The maximum number of switches supported with codegen.
val MAX_NUM_CASES_FOR_CODEGEN = 20

def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
CaseWhen(branches, Option(elseValue))
}
Expand All @@ -242,7 +265,6 @@ object CaseWhen {
}
}


/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* When a = b, returns c; when a = d, returns e; else returns f.
Expand Down
Expand Up @@ -104,7 +104,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("Subquery", Once,
OptimizeSubqueries) :: Nil
OptimizeSubqueries) ::
Batch("OptimizeCodegen", Once,
OptimizeCodegen(conf)) :: Nil
}

/**
Expand Down Expand Up @@ -863,6 +865,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
}
}

/**
 * Rewrites each [[CaseWhen]] expression in the plan into its [[CaseWhenCodegen]] form when
 * its number of branches is strictly less than `conf.maxCaseBranchesForCodegen`.
 * A [[CaseWhen]] with that many branches or more is left untouched.
 *
 * @param conf configuration supplying `maxCaseBranchesForCodegen`
 */
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
    case caseWhen: CaseWhen if caseWhen.branches.size < conf.maxCaseBranchesForCodegen =>
      caseWhen.toCodegen()
  }
}

/**
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
Expand Down
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


/**
 * Checks that the OptimizeCodegen rule converts CaseWhen into CaseWhenCodegen only when the
 * number of branches is below the configured limit (SimpleCatalystConf defaults it to 20).
 */
class OptimizeCodegenSuite extends PlanTest {

  // Runs only the rule under test, once, using SimpleCatalystConf's default limit.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("OptimizeCodegen", Once,
        OptimizeCodegen(SimpleCatalystConf(caseSensitiveAnalysis = true))) :: Nil
  }

  /**
   * Wraps each expression in a single-column projection over OneRowRelation, optimizes the
   * plan built from `e1`, and compares it with the analyzed (unoptimized) plan built from `e2`.
   */
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  test("Codegen only when the number of branches is small.") {
    // A single branch is below the limit, so it is converted.
    assertEquivalent(
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())

    // 100 branches is above the limit, so the expression is left as is.
    // The branch tuple is written explicitly rather than relying on argument adaptation.
    assertEquivalent(
      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)))
  }

  test("Nested CaseWhen Codegen.") {
    // CaseWhen expressions nested in a branch condition and in the else value are converted
    // along with the enclosing one.
    assertEquivalent(
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
  }

  test("Multiple CaseWhen in one operator.") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
    // The 20-branch expression is not below the default limit of 20, so it stays a CaseWhen.
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }

  test("Multiple CaseWhen in different operators") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    // Expressions in both the projection and the filter are rewritten; the 20-branch ones
    // are left unchanged in both places.
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }
}
Expand Up @@ -429,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {

private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
case e: CaseWhen => e.shouldCodegen
// CodegenFallback requires the input to be an InternalRow
case e: CodegenFallback => false
case _ => true
Expand Down
Expand Up @@ -402,6 +402,12 @@ object SQLConf {
.intConf
.createWithDefault(200)

// Internal setting read through SQLConf.maxCaseBranchesForCodegen (default 20). The
// OptimizeCodegen rule compares a CaseWhen's branch count against it with a strict `<`.
val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches")
.internal()
.doc("The maximum number of switches supported with codegen.")
.intConf
.createWithDefault(20)

val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
Expand Down Expand Up @@ -529,6 +535,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)

// Current value of the MAX_CASES_BRANCHES entry (spark.sql.codegen.maxCaseBranches).
def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)

def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)

def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)
Expand Down

0 comments on commit 3d46d79

Please sign in to comment.