Skip to content

Commit

Permalink
[SPARK-47563][SQL] Add map normalization on creation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Added normalization of map keys when they are put in `ArrayBasedMapBuilder`.

### Why are the changes needed?
Map keys must be unique, so floating-point keys need to be normalized when a map is built; otherwise `0.0` and `-0.0` are treated as distinct keys, which allows the following case: `Map(0.0, -0.0)`.
This also unblocks the GROUP BY statement for map types, as per the discussion in #45549 (comment).

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
New UTs in `ArrayBasedMapBuilderSuite`

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45721 from stevomitric/stevomitric/fix-map-dup.

Authored-by: Stevo Mitric <stevo.mitric@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
stevomitric authored and cloud-fan committed Mar 27, 2024
1 parent d326cb9 commit 87449c3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection.mutable

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -52,18 +53,25 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria

private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)

/**
 * Function applied to a map key in `put` before the key is looked up in `keyToIndex` and
 * before it is stored, so that lookup and storage always see the same normalized form.
 *
 * Only a top-level `FloatType` or `DoubleType` key is normalized; every other key type
 * (including complex types) is passed through unchanged.
 */
private lazy val keyNormalizer: Any => Any = {
  // No-op used for key types that have no floating point normalizer.
  val passThrough: Any => Any = (k: Any) => k
  keyType match {
    case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
    case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
    case _ => passThrough
  }
}

def put(key: Any, value: Any): Unit = {
if (key == null) {
throw QueryExecutionErrors.nullAsMapKeyNotAllowedError()
}

val index = keyToIndex.getOrDefault(key, -1)
val keyNormalized = keyNormalizer(key)
val index = keyToIndex.getOrDefault(keyNormalized, -1)
if (index == -1) {
if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw QueryExecutionErrors.exceedMapSizeLimitError(size)
}
keyToIndex.put(key, values.length)
keys.append(key)
keyToIndex.put(keyNormalized, values.length)
keys.append(keyNormalized)
values.append(value)
} else {
if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, DoubleType, IntegerType, StructType}
import org.apache.spark.unsafe.Platform

class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {
Expand Down Expand Up @@ -60,6 +60,26 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {
)
}

test("apply key normalization when creating") {
  // -0.0 and 0.0 must be recognised as the same key: inserting the second one is expected
  // to be reported as a duplicate (no dedup policy override is set in this test).
  val mapBuilder = new ArrayBasedMapBuilder(DoubleType, IntegerType)
  mapBuilder.put(-0.0, 1)

  val duplicateKeyError = intercept[SparkRuntimeException] {
    mapBuilder.put(0.0, 2)
  }

  checkError(
    exception = duplicateKeyError,
    errorClass = "DUPLICATED_MAP_KEY",
    parameters = Map(
      "key" -> "0.0",
      "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\""))
}

test("successful map normalization on build") {
  // A single -0.0 key is inserted; the built map must contain exactly one entry whose key
  // has been normalized to positive zero.
  val builder = new ArrayBasedMapBuilder(DoubleType, IntegerType)
  builder.put(-0.0, 1)
  val map = builder.build()
  assert(map.numElements() == 1)
  assert(ArrayBasedMapData.toScalaMap(map) == Map(0.0 -> 1))

  // Scala's `==` treats -0.0 and 0.0 as equal, so the map comparison above would also pass
  // for a key that was NOT normalized. Compare raw bit patterns to verify that the stored
  // key really is +0.0 (the sign bit distinguishes the two zeros).
  val storedKey = map.keyArray.getDouble(0)
  assert(java.lang.Double.doubleToRawLongBits(storedKey) ==
    java.lang.Double.doubleToRawLongBits(0.0))
}

test("remove duplicated keys with last wins policy") {
withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
Expand Down

0 comments on commit 87449c3

Please sign in to comment.