diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0e8bb84ee5d81..4595ea049ef70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -705,6 +705,13 @@ trait HashJoin extends JoinCodegenSupport { } object HashJoin extends CastSupport with SQLConfHelper { + + private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = { + // TODO: support BooleanType, DateType and TimestampType + keys.forall(_.dataType.isInstanceOf[IntegralType]) && + keys.map(_.dataType.defaultSize).sum <= 8 + } + /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. * @@ -712,9 +719,7 @@ object HashJoin extends CastSupport with SQLConfHelper { */ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { assert(keys.nonEmpty) - // TODO: support BooleanType, DateType and TimestampType - if (keys.exists(!_.dataType.isInstanceOf[IntegralType]) - || keys.map(_.dataType.defaultSize).sum > 8) { + if (!canRewriteAsLongType(keys)) { return keys } @@ -736,18 +741,28 @@ object HashJoin extends CastSupport with SQLConfHelper { * determine the number of bits to shift */ def extractKeyExprAt(keys: Seq[Expression], index: Int): Expression = { + assert(canRewriteAsLongType(keys)) // jump over keys that have a higher index value than the required key if (keys.size == 1) { assert(index == 0) - cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) + Cast( + child = BoundReference(0, LongType, nullable = false), + dataType = keys(index).dataType, + timeZoneId = Option(conf.sessionLocalTimeZone), + ansiEnabled = false) } else { val shiftedBits = keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1 // build the schema for 
unpacking the required key - cast(BitwiseAnd( + val castChild = BitwiseAnd( ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)), - Literal(mask)), keys(index).dataType) + Literal(mask)) + Cast( + child = castChild, + dataType = keys(index).dataType, + timeZoneId = Option(conf.sessionLocalTimeZone), + ansiEnabled = false) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index b8ffc47d6ec3c..d5b7ed6c275f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.map.BytesToBytesMap @@ -610,14 +611,19 @@ class HashedRelationSuite extends SharedSparkSession { val keys = Seq(BoundReference(0, ByteType, false), BoundReference(1, IntegerType, false), BoundReference(2, ShortType, false)) - val packed = HashJoin.rewriteKeyExpr(keys) - val unsafeProj = UnsafeProjection.create(packed) - val packedKeys = unsafeProj(row) - - Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) => - val key = HashJoin.extractKeyExprAt(keys, i) - val proj = UnsafeProjection.create(key) - assert(proj(packedKeys).get(0, dt) == -i - 1) + // Rewriting and extracting key expressions should not cause an exception when ANSI mode is on. 
+ Seq("false", "true").foreach { ansiEnabled => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) { + val packed = HashJoin.rewriteKeyExpr(keys) + val unsafeProj = UnsafeProjection.create(packed) + val packedKeys = unsafeProj(row) + + Seq((0, ByteType), (1, IntegerType), (2, ShortType)).foreach { case (i, dt) => + val key = HashJoin.extractKeyExprAt(keys, i) + val proj = UnsafeProjection.create(key) + assert(proj(packedKeys).get(0, dt) == -i - 1) + } + } } }