Skip to content

Commit

Permalink
[SPARK-22984] Fix incorrect bitmap copying and offset adjustment in G…
Browse files Browse the repository at this point in the history
…enerateUnsafeRowJoiner

## What changes were proposed in this pull request?

This PR fixes a longstanding correctness bug in `GenerateUnsafeRowJoiner`. This class was introduced in #7821 (July 2015 / Spark 1.5.0+) and is used to combine pairs of UnsafeRows in TungstenAggregationIterator, CartesianProductExec, and AppendColumns.

### Bugs fixed by this patch

1. **Incorrect combining of null-tracking bitmaps**: when concatenating two UnsafeRows, the implementation "Concatenate the two bitsets together into a single one, taking padding into account". If one row has no columns then it has a bitset size of 0, but the code was incorrectly assuming that if the left row had a non-zero number of fields then the right row would also have at least one field, so it was copying invalid bytes and and treating them as part of the bitset. I'm not sure whether this bug was also present in the original implementation or whether it was introduced in #7892 (which fixed another bug in this code).
2. **Incorrect updating of data offsets for null variable-length fields**: after updating the bitsets and copying fixed-length and variable-length data, we need to perform adjustments to the offsets pointing the start of variable length fields's data. The existing code was _conditionally_ adding a fixed offset to correct for the new length of the combined row, but it is unsafe to do this if the variable-length field has a null value: we always represent nulls by storing `0` in the fixed-length slot, but this code was incorrectly incrementing those values. This bug was present since the original version of `GenerateUnsafeRowJoiner`.

### Why this bug remained latent for so long

The PR which introduced `GenerateUnsafeRowJoiner` features several randomized tests, including tests of the cases where one side of the join has no fields and where string-valued fields are null. However, the existing assertions were too weak to uncover this bug:

- If a null field has a non-zero value in its fixed-length data slot then this will not cause problems for field accesses because the null-tracking bitmap should still be correct and we will not try to use the incorrect offset for anything.
- If the null tracking bitmap is corrupted by joining against a row with no fields then the corruption occurs in field numbers past the actual field numbers contained in the row. Thus valid `isNullAt()` calls will not read the incorrectly-set bits.

The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and `isNullAt()`, but didn't actually check the UnsafeRows for bit-for-bit equality, preventing these bugs from failing assertions. It turns out that there was even a [GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala) but it looks like it also didn't catch this problem because it only tested the bitsets in an end-to-end fashion by accessing them through the `UnsafeRow` interface instead of actually comparing the bitsets' bytes.

### Impact of these bugs

- This bug will cause `equals()` and `hashCode()` to be incorrect for these rows, which will be problematic in case`GenerateUnsafeRowJoiner`'s results are used as join or grouping keys.
- Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in reads from invalid null bitmap positions causing fields to incorrectly become NULL (see the end-to-end example below).
  - It looks like this generally only happens in `CartesianProductExec`, which our query optimizer often avoids executing (usually we try to plan a `BroadcastNestedLoopJoin` instead).

### End-to-end test case demonstrating the problem

The following query demonstrates how this bug may result in incorrect query results:

```sql
set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec

create table a as select * from values 1;
create table b as select * from values 2;

SELECT
  t3.col1,
  t1.col1
FROM a t1
CROSS JOIN b t2
CROSS JOIN b t3
```

This should return `(2, 1)` but instead was returning `(null, 1)`.

Column pruning ends up trimming off all columns from `t2`, so when `t2` joins with another table this triggers the bitmap-copying bug. This incorrect bitmap is subsequently copied again when performing the final join, causing the final output to have an incorrectly-set null bit for the first field.

## How was this patch tested?

Strengthened the assertions in existing tests in GenerateUnsafeRowJoinerSuite. Also verified that the end-to-end test case which uncovered this now passes.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #20181 from JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs.
  • Loading branch information
JoshRosen authored and cloud-fan committed Jan 9, 2018
1 parent 849043c commit f20131d
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U

// --------------------- copy bitset from row 1 and row 2 --------------------------- //
val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
val bits = if (bitset1Remainder > 0) {
val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
if (i < bitset1Words - 1) {
s"$getLong(obj1, offset1 + ${i * 8})"
} else if (i == bitset1Words - 1) {
Expand Down Expand Up @@ -152,22 +152,65 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
} else {
// Number of bytes to increase for the offset. Note that since in UnsafeRow we store the
// offset in the upper 32 bit of the words, we can just shift the offset to the left by
// 32 and increment that amount in place.
// 32 and increment that amount in place. However, we need to handle the important special
// case of a null field, in which case the offset should be zero and should not have a
// shift added to it.
val shift =
if (i < schema1.size) {
s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
} else {
s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
}
val cursor = offset + outputBitsetWords * 8 + i * 8
s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n"
// UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
// output as a de-facto specification for the internal layout of data.
//
// Null-valued fields will always have a data offset of 0 because
// UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 to in field's
// position in the fixed-length section of the row. As a result, we must NOT add
// `shift` to the offset for null fields.
//
// We could perform a null-check here by inspecting the null-tracking bitmap, but doing
// so could be expensive and will add significant bloat to the generated code. Instead,
// we'll rely on the invariant "stored offset == 0 for variable-length data type implies
// that the field's value is null."
//
// To establish that this invariant holds, we'll prove that a non-null field can never
// have a stored offset of 0. There are two cases to consider:
//
// 1. The non-null field's data is of non-zero length: reading this field's value
// must read data from the variable-length section of the row, so the stored offset
// will actually be used in address calculation and must be correct. The offsets
// count bytes from the start of the UnsafeRow so these offsets will always be
// non-zero because the storage of the offsets themselves takes up space at the
// start of the row.
// 2. The non-null field's data is of zero length (i.e. its data is empty). In this
// case, we have to worry about the possibility that an arbitrary offset value was
// stored because we never actually read any bytes using this offset and therefore
// would not crash if it was incorrect. The variable-sized data writing paths in
// UnsafeRowWriter unconditionally calls setOffsetAndSize(ordinal, numBytes) with
// no special handling for the case where `numBytes == 0`. Internally,
// setOffsetAndSize computes the offset without taking the size into account. Thus
// the stored offset is the same non-zero offset that would be used if the field's
// dataSize was non-zero (and in (1) above we've shown that case behaves as we
// expect).
//
// Thus it is safe to perform `existingOffset != 0` checks here in the place of
// more expensive null-bit checks.
s"""
|existingOffset = $getLong(buf, $cursor);
|if (existingOffset != 0) {
| $putLong(buf, $cursor, existingOffset + ($shift << 32));
|}
""".stripMargin
}
}

val updateOffsets = ctx.splitExpressions(
expressions = updateOffset,
funcName = "copyBitsetFunc",
arguments = ("long", "numBytesVariableRow1") :: Nil)
arguments = ("long", "numBytesVariableRow1") :: Nil,
makeSplitFunction = (s: String) => "long existingOffset;\n" + s)

// ------------------------ Finally, put everything together --------------------------- //
val codeBody = s"""
Expand Down Expand Up @@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| $copyFixedLengthRow2
| $copyVariableLengthRow1
| $copyVariableLengthRow2
| long existingOffset;
| $updateOffsets
|
| out.pointTo(buf, sizeInBytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Test suite for [[GenerateUnsafeRowJoiner]].
Expand All @@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
testConcat(64, 64, fixed)
}

test("rows with all empty strings") {
val schema = StructType(Seq(
StructField("f1", StringType), StructField("f2", StringType)))
val row: UnsafeRow = UnsafeProjection.create(schema).apply(
InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8))
testConcat(schema, row, schema, row)
}

test("rows with all empty int arrays") {
val schema = StructType(Seq(
StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType))))
val emptyIntArray =
ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0)
val row: UnsafeRow = UnsafeProjection.create(schema).apply(
InternalRow(emptyIntArray, emptyIntArray))
testConcat(schema, row, schema, row)
}

test("alternating empty and non-empty strings") {
val schema = StructType(Seq(
StructField("f1", StringType), StructField("f2", StringType)))
val row: UnsafeRow = UnsafeProjection.create(schema).apply(
InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo")))
testConcat(schema, row, schema, row)
}

test("randomized fix width types") {
for (i <- 0 until 20) {
testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
Expand Down Expand Up @@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
testConcat(schema1, row1, schema2, row2)
}

private def testConcat(
schema1: StructType,
row1: UnsafeRow,
schema2: StructType,
row2: UnsafeRow) {

// Run the joiner.
val mergedSchema = StructType(schema1 ++ schema2)
val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
val output = concater.join(row1, row2)
val output: UnsafeRow = concater.join(row1, row2)

// We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures
// that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written
// correctly.
val expectedOutput: UnsafeRow = {
val joinedRowProjection = UnsafeProjection.create(mergedSchema)
val joined = new JoinedRow()
joinedRowProjection.apply(joined.apply(row1, row2))
}

// Test everything equals ...
for (i <- mergedSchema.indices) {
val dataType = mergedSchema(i).dataType
if (i < schema1.size) {
assert(output.isNullAt(i) === row1.isNullAt(i))
if (!output.isNullAt(i)) {
assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
assert(output.get(i, dataType) === row1.get(i, dataType))
assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
}
} else {
assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
if (!output.isNullAt(i)) {
assert(output.get(i, mergedSchema(i).dataType) ===
row2.get(i - schema1.size, mergedSchema(i).dataType))
assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType))
assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
}
}
}


assert(
expectedOutput.getSizeInBytes == output.getSizeInBytes,
"output isn't same size in bytes as slow path")

// Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in
// case this assertion fails:
val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]]
.take(output.getSizeInBytes)
val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]]
.take(expectedOutput.getSizeInBytes)

val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields())
val actualBitset = actualBytes.take(bitsetWidth)
val expectedBitset = expectedBytes.take(bitsetWidth)
assert(actualBitset === expectedBitset, "bitsets were not equal")

val fixedLengthSize = expectedOutput.numFields() * 8
val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
if (actualFixedLength !== expectedFixedLength) {
actualFixedLength.grouped(8)
.zip(expectedFixedLength.grouped(8))
.zip(mergedSchema.fields.toIterator)
.foreach {
case ((actual, expected), field) =>
assert(actual === expected, s"Fixed length sections are not equal for field $field")
}
fail("Fixed length sections were not equal")
}

val variableLengthStart = bitsetWidth + fixedLengthSize
val actualVariableLength = actualBytes.drop(variableLengthStart)
val expectedVariableLength = expectedBytes.drop(variableLengthStart)
assert(actualVariableLength === expectedVariableLength, "fixed length sections were not equal")

assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal")
}

}

0 comments on commit f20131d

Please sign in to comment.