Skip to content

Commit

Permalink
[SPARK-27097] Avoid embedding platform-dependent offsets literally in…
Browse files Browse the repository at this point in the history
… whole-stage generated code

## What changes were proposed in this pull request?

Spark SQL performs whole-stage code generation to speed up query execution. There are two steps to it:
- Java source code is generated from the physical query plan on the driver. A single version of the source code is generated from a query plan, and sent to all executors.
  - It's compiled to bytecode on the driver to catch compilation errors before sending to executors, but currently only the generated source code gets sent to the executors. The bytecode compilation is for fail-fast only.
- Executors receive the generated source code and compile to bytecode, then the query runs like a hand-written Java program.

In this model, there's an implicit assumption about the driver and executors being run on similar platforms. Some code paths accidentally embedded platform-dependent object layout information into the generated code, such as:
```java
Platform.putLong(buffer, /* offset */ 24, /* value */ 1);
```
This code expects a field to be at offset +24 of the `buffer` object, and sets a value to that field.
But whole-stage code generation generally uses platform-dependent information from the driver. If the object layout is significantly different on the driver and executors, the generated code can be reading/writing to wrong offsets on the executors, causing all kinds of data corruption.

One code pattern that leads to such problem is the use of `Platform.XXX` constants in generated code, e.g. `Platform.BYTE_ARRAY_OFFSET`.

Bad:
```scala
val baseOffset = Platform.BYTE_ARRAY_OFFSET
// codegen template:
s"Platform.putLong($buffer, $baseOffset, $value);"
```
This will embed the value of `Platform.BYTE_ARRAY_OFFSET` on the driver into the generated code.

Good:
```scala
val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
// codegen template:
s"Platform.putLong($buffer, $baseOffset, $value);"
```
This will generate the offset symbolically -- `Platform.putLong(buffer, Platform.BYTE_ARRAY_OFFSET, value)`, which will be able to pick up the correct value on the executors.

Caveat: these offset constants are declared as runtime-initialized `static final` in Java, so they're not compile-time constants from the Java language's perspective. It does lead to a slightly increased size of the generated code, but this is necessary for correctness.

NOTE: there can be other patterns that generate platform-dependent code on the driver which is invalid on the executors. e.g. if the endianness is different between the driver and the executors, and if some generated code makes strong assumption about endianness, it would also be problematic.

## How was this patch tested?

Added a new test suite `WholeStageCodegenSparkSubmitSuite`. This test suite needs to set the driver's extraJavaOptions to force the driver and executor use different Java object layouts, so it's run as an actual SparkSubmit job.

Authored-by: Kris Mok <kris.mokdatabricks.com>

Closes #24031 from gatorsmile/cherrypickSPARK-27097.

Lead-authored-by: Kris Mok <kris.mok@databricks.com>
Co-authored-by: gatorsmile <gatorsmile@gmail.com>
Signed-off-by: DB Tsai <d_tsai@apple.com>
  • Loading branch information
2 people authored and dbtsai committed Mar 9, 2019
1 parent 326fc74 commit 57ae251
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen

private byte[] arr = new byte[1024 * 1024];
private Object baseObject = arr;
private final long baseOffset = Platform.BYTE_ARRAY_OFFSET;
private final TaskContext taskContext = TaskContext.get();

public UnsafeSorterSpillReader(
Expand Down Expand Up @@ -125,7 +124,7 @@ public Object getBaseObject() {

@Override
public long getBaseOffset() {
return baseOffset;
return Platform.BYTE_ARRAY_OFFSET;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U

def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
val ctx = new CodegenContext
val offset = Platform.BYTE_ARRAY_OFFSET
val offset = "Platform.BYTE_ARRAY_OFFSET"
val getLong = "Platform.getLong"
val putLong = "Platform.putLong"

Expand Down Expand Up @@ -92,7 +92,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
}
}
s"$putLong(buf, ${offset + i * 8}, $bits);\n"
s"$putLong(buf, $offset + ${i * 8}, $bits);\n"
}

val copyBitsets = ctx.splitExpressions(
Expand All @@ -102,12 +102,12 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
("java.lang.Object", "obj2") :: ("long", "offset2") :: Nil)

// --------------------- copy fixed length portion from row 1 ----------------------- //
var cursor = offset + outputBitsetWords * 8
var cursor = outputBitsetWords * 8
val copyFixedLengthRow1 = s"""
|// Copy fixed length data for row1
|Platform.copyMemory(
| obj1, offset1 + ${bitset1Words * 8},
| buf, $cursor,
| buf, $offset + $cursor,
| ${schema1.size * 8});
""".stripMargin
cursor += schema1.size * 8
Expand All @@ -117,7 +117,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|// Copy fixed length data for row2
|Platform.copyMemory(
| obj2, offset2 + ${bitset2Words * 8},
| buf, $cursor,
| buf, $offset + $cursor,
| ${schema2.size * 8});
""".stripMargin
cursor += schema2.size * 8
Expand All @@ -129,7 +129,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
|Platform.copyMemory(
| obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
| buf, $cursor,
| buf, $offset + $cursor,
| numBytesVariableRow1);
""".stripMargin

Expand All @@ -140,7 +140,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
|long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
|Platform.copyMemory(
| obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
| buf, $cursor + numBytesVariableRow1,
| buf, $offset + $cursor + numBytesVariableRow1,
| numBytesVariableRow2);
""".stripMargin

Expand All @@ -161,7 +161,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
} else {
s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
}
val cursor = offset + outputBitsetWords * 8 + i * 8
val cursor = outputBitsetWords * 8 + i * 8
// UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
// output as a de-facto specification for the internal layout of data.
//
Expand Down Expand Up @@ -198,9 +198,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
// Thus it is safe to perform `existingOffset != 0` checks here in the place of
// more expensive null-bit checks.
s"""
|existingOffset = $getLong(buf, $cursor);
|existingOffset = $getLong(buf, $offset + $cursor);
|if (existingOffset != 0) {
| $putLong(buf, $cursor, existingOffset + ($shift << 32));
| $putLong(buf, $offset + $cursor, existingOffset + ($shift << 32));
|}
""".stripMargin
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
val z = ctx.freshName("z")
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"

val baseOffset = Platform.BYTE_ARRAY_OFFSET
val baseOffset = "Platform.BYTE_ARRAY_OFFSET"
val wordSize = UnsafeRow.WORD_SIZE
val structSizeAsLong = s"${structSize}L"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.TimeLimits

import org.apache.spark.{SparkFunSuite, TestUtils}
import org.apache.spark.deploy.SparkSubmitSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession}
import org.apache.spark.sql.functions.{array, col, count, lit}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.ResetSystemProperties

// Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit.
class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite
with Matchers
with BeforeAndAfterEach
with ResetSystemProperties {

test("Generated code on driver should not embed platform-specific constant") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)

// HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
// settings of UseCompressedOops JVM option.
val argsForSparkSubmit = Seq(
"--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
"--master", "local-cluster[1,1,1024]",
"--driver-memory", "1g",
"--conf", "spark.ui.enabled=false",
"--conf", "spark.master.rest.enabled=false",
"--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
"--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
unusedJar.toString)
SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")
}
}

object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {

var spark: SparkSession = _

def main(args: Array[String]): Unit = {
TestUtils.configTestLog4j("INFO")

spark = SparkSession.builder().getOrCreate()

// Make sure the test is run where the driver and the executors uses different object layouts
val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET
val executorArrayHeaderSize =
spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt
assert(driverArrayHeaderSize > executorArrayHeaderSize)

val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
.groupBy(array(col("v"))).agg(count(col("*")))
val plan = df.queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)

val expectedAnswer =
Row(Array(0), 7178) ::
Row(Array(1), 7178) ::
Row(Array(2), 7178) ::
Row(Array(3), 7177) ::
Row(Array(4), 7177) ::
Row(Array(5), 7177) ::
Row(Array(6), 7177) ::
Row(Array(7), 7177) ::
Row(Array(8), 7177) ::
Row(Array(9), 7177) :: Nil
val result = df.collect
QueryTest.sameRows(result.toSeq, expectedAnswer) match {
case Some(errMsg) => fail(errMsg)
case _ =>
}
}
}

0 comments on commit 57ae251

Please sign in to comment.