From 23c1f5be4cf640c01fd9ff0581eb8269f2800bf7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 29 Jun 2016 20:16:31 +0800 Subject: [PATCH] add null check for key when create map data in encoder --- .../apache/spark/sql/catalyst/util/ArrayBasedMapData.scala | 5 +++++ .../sql/catalyst/encoders/ExpressionEncoderSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 4449da13c083c..2e0ac8a6e6d54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.catalyst.util class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) + for (i <- 0 until keyArray.numElements()) { + if (keyArray.isNullAt(i)) { + throw new RuntimeException("Cannot use null as map key!") + } + } override def numElements(): Int = keyArray.numElements() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index a1f9259f139ed..4df9062018995 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -328,6 +328,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } } + test("null check for map key") { + val encoder = ExpressionEncoder[Map[String, Int]]() + val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2)))) + assert(e.getMessage.contains("Cannot use null as map key")) + } + private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = {