Skip to content

Commit

Permalink
[SPARK-28361][SQL][TEST] Test equality of generated code with id in c…
Browse files Browse the repository at this point in the history
…lass name

A codegen test in WholeStageCodegenSuite was flaky because it used the codegen metrics class to test whether the generated code for equivalent plans was identical under a particular flag. This patch switches the test to compare the generated code directly.

N/A

Closes #25131 from gatorsmile/WholeStageCodegenSuite.

Authored-by: gatorsmile <gatorsmile@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
  • Loading branch information
gatorsmile authored and dongjoon-hyun committed Jul 12, 2019
1 parent aa41dce commit 60b89cf
Showing 1 changed file with 19 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.spark.sql.execution

import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.{QueryTest, Row, SaveMode}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
Expand Down Expand Up @@ -145,10 +143,10 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
.select("int")

val plan = df.queryExecution.executedPlan
assert(!plan.find(p =>
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].child.children(0)
.isInstanceOf[SortMergeJoinExec]).isDefined)
.isInstanceOf[SortMergeJoinExec]).isEmpty)
assert(df.collect() === Array(Row(1), Row(2)))
}
}
Expand Down Expand Up @@ -181,6 +179,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2
}

/**
 * Generates the whole-stage code for every [[WholeStageCodegenExec]] node found in the
 * executed plan of `ds`.
 *
 * The assertion fails when the executed plan contains no [[WholeStageCodegenExec]] node.
 *
 * @param ds the dataset whose executed plan is inspected
 * @return one [[CodeAndComment]] per collected [[WholeStageCodegenExec]] node, in the order
 *         the nodes were collected from the plan
 */
def genCode(ds: Dataset[_]): Seq[CodeAndComment] = {
  val codegenStages = ds.queryExecution.executedPlan.collect {
    case stage: WholeStageCodegenExec => stage
  }
  assert(codegenStages.nonEmpty, "WholeStageCodegenExec is expected")
  codegenStages.map { stage =>
    // doCodeGen() returns a pair; only its second element (the generated code) is kept.
    val (_, generated) = stage.doCodeGen()
    generated
  }
}

ignore("SPARK-21871 check if we can get large code size when compiling too long functions") {
val codeWithShortFunctions = genGroupByCode(3)
val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions)
Expand Down Expand Up @@ -241,9 +246,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
val df = spark.range(100)
val join = df.join(df, "id")
val plan = join.queryExecution.executedPlan
assert(!plan.find(p =>
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isDefined,
p.asInstanceOf[WholeStageCodegenExec].codegenStageId == 0).isEmpty,
"codegen stage IDs should be preserved through ReuseExchange")
checkAnswer(join, df.toDF)
}
Expand All @@ -253,18 +258,13 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
import testImplicits._

withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") {
val bytecodeSizeHisto = CodegenMetrics.METRIC_COMPILATION_TIME

// the same query run twice should hit the codegen cache
spark.range(3).select('id + 2).collect
val after1 = bytecodeSizeHisto.getCount
spark.range(3).select('id + 2).collect
val after2 = bytecodeSizeHisto.getCount // same query shape as above, deliberately
// bytecodeSizeHisto's count is always monotonically increasing if new compilation to
// bytecode had occurred. If the count stayed the same that means we've got a cache hit.
assert(after1 == after2, "Should hit codegen cache. No new compilation to bytecode expected")

// a different query can result in codegen cache miss, that's by design
// the same query run twice should produce identical code, which would imply a hit in
// the generated code cache.
val ds1 = spark.range(3).select('id + 2)
val code1 = genCode(ds1)
val ds2 = spark.range(3).select('id + 2)
val code2 = genCode(ds2) // same query shape as above, deliberately
assert(code1 == code2, "Should produce same code")
}
}

Expand Down

0 comments on commit 60b89cf

Please sign in to comment.