[SPARK-38325][SQL] ANSI mode: avoid potential runtime error in HashJoin.extractKeyExprAt() #35659

Changes from 1 commit
@@ -739,15 +739,24 @@ object HashJoin extends CastSupport with SQLConfHelper {
    // jump over keys that have a higher index value than the required key
    if (keys.size == 1) {
      assert(index == 0)
-      cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
+      Cast(
+        child = BoundReference(0, LongType, nullable = false),
+        dataType = keys(index).dataType,
+        timeZoneId = Option(conf.sessionLocalTimeZone),
Contributor

It seems that rewriteKeyExpr() above does not support timezone-related data types yet (line 716). Shall we add an assertion in this method that the key data types are IntegralType and that the sum of the key sizes is no larger than 8 bytes?

Member Author

Sure, added assertion.
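
For reference, a minimal sketch of the kind of guard being discussed; the actual assertion lands in a later commit that is not part of this one-commit view, and the helper name below is hypothetical:

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType}

// Hypothetical guard: reject key layouts that the packed-long fast path cannot handle.
def assertPackableKeys(keys: Seq[Expression]): Unit = {
  val integralTypes: Set[DataType] = Set(ByteType, ShortType, IntegerType, LongType)
  assert(keys.forall(k => integralTypes.contains(k.dataType)),
    "extractKeyExprAt only supports integral key types")
  assert(keys.map(_.dataType.defaultSize).sum <= 8,
    "the packed key widths must fit into a single 8-byte long")
}
```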

Member Author (@gengliangwang, Feb 25, 2022)

I am keeping the timeZoneId here since we may support DateType/TimestampType in the future.

+        ansiEnabled = false)
Contributor

So we do expect to get a null value if overflow happens?

cc @c21

Member Author

The default cast returns the lower bits when an overflow happens.
E.g.

> select cast(1025L as byte);
1
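
For illustration, the same lower-bits semantics in plain Scala (a JVM narrowing conversion, not a quote from the Spark code):

```scala
// 1025 = 0b100_0000_0001; keeping only the low 8 bits yields 1.
val v = 1025L
assert(v.toByte == 1)       // JVM narrowing keeps the low-order byte
assert((v & 0xFFL) == 1L)   // equivalent to masking to 8 bits first
```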

Contributor

Sorry, just for my understanding (without running the unit test myself): why do we expect overflow to happen here? Shouldn't rewriteKeyExpr() above already guarantee that we only rewrite keys to long when all the keys fit into a long type?

Member Author

@c21 this PR is to avoid exceptions on overflow. If you run HashedRelationSuite with ANSI mode on, there will be an overflow error.
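
For context, a quick way to see the two cast behaviors side by side; this is a hedged sketch for a spark-shell style session, assuming a SparkSession named spark:

```scala
// Default (non-ANSI) cast: overflow truncates to the lower bits.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT CAST(1025L AS BYTE)").show()   // prints 1

// ANSI cast: the same value raises a cast-overflow error, which is what
// extractKeyExprAt() has to avoid when unpacking join keys.
spark.conf.set("spark.sql.ansi.enabled", "true")
spark.sql("SELECT CAST(1025L AS BYTE)").show()   // throws an overflow error
```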

Contributor

@gengliangwang - got it, thanks for the explanation.

Contributor

Ah, I see. So this algorithm expects the equivalent of the Java code byte b = (byte) long_value, which is the same as the non-ANSI Cast behavior.

    } else {
      val shiftedBits =
        keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
      val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
      // build the schema for unpacking the required key
-      cast(BitwiseAnd(
Do we need to check cast() used in other places to see if they have the same issue?

Member Author

Yes, I checked and the other usages look fine.

+      val castChild = BitwiseAnd(
        ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
-        Literal(mask)), keys(index).dataType)
+        Literal(mask))
+      Cast(
+        child = castChild,
+        dataType = keys(index).dataType,
+        timeZoneId = Option(conf.sessionLocalTimeZone),
+        ansiEnabled = false)
    }
  }
}
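
To make the shift-and-mask logic above concrete, here is a plain Scala sketch of the packing that rewriteKeyExpr() performs and the unpacking done by this method, using the (byte, int, short) keys exercised in HashedRelationSuite below; the packing arithmetic here is a paraphrase for illustration, not the actual Spark code:

```scala
// Pack (byte -1, int -2, short -3) into one long, first key in the highest bits:
// shift the accumulated value left by the next key's width, then OR the key in.
val b: Byte = -1; val i: Int = -2; val s: Short = -3
var packed: Long = b.toLong
packed = (packed << 32) | (i.toLong & 0xFFFFFFFFL)   // append the 32-bit int
packed = (packed << 16) | (s.toLong & 0xFFFFL)       // append the 16-bit short

// Unpack key 0 (the byte): shiftedBits = 32 + 16 = 48, mask = 0xFF.
val unpacked = (packed >>> 48) & 0xFFL
assert(unpacked == 255L)        // the masked value falls outside the byte range
assert(unpacked.toByte == -1)   // a lower-bits cast recovers the original key
// An ANSI cast of 255L to ByteType would raise an overflow error here,
// which is why the Cast above is built with ansiEnabled = false.
```
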
@@ -30,6 +30,7 @@ import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -610,14 +611,19 @@ class HashedRelationSuite extends SharedSparkSession {
    val keys = Seq(BoundReference(0, ByteType, false),
      BoundReference(1, IntegerType, false),
      BoundReference(2, ShortType, false))
-    val packed = HashJoin.rewriteKeyExpr(keys)
-    val unsafeProj = UnsafeProjection.create(packed)
-    val packedKeys = unsafeProj(row)
-
-    Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) =>
-      val key = HashJoin.extractKeyExprAt(keys, i)
-      val proj = UnsafeProjection.create(key)
-      assert(proj(packedKeys).get(0, dt) == -i - 1)
+    // Rewriting and extracting key expressions should not cause exceptions when ANSI mode is on.
+    Seq("false", "true").foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+        val packed = HashJoin.rewriteKeyExpr(keys)
+        val unsafeProj = UnsafeProjection.create(packed)
+        val packedKeys = unsafeProj(row)
+
+        Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) =>
Contributor

nit: is it possible to add some overflow-value tests as well? Non-blocking comment.

Member Author

This case can already cause overflow on unpacking.

+          val key = HashJoin.extractKeyExprAt(keys, i)
+          val proj = UnsafeProjection.create(key)
+          assert(proj(packedKeys).get(0, dt) == -i - 1)
+        }
+      }
    }
  }
