Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23933][SQL] Add map_from_arrays function #21258

Closed
wants to merge 13 commits into from
19 changes: 19 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,25 @@ def create_map(*cols):
return Column(jc)


@since(2.4)
def map_from_arrays(col1, col2):
"""Creates a new map from two arrays.

:param col1: name of column containing a set of keys. All elements should not be null
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and duplicated?

:param col2: name of column containing a set of values

>>> df = spark.createDataFrame([([2, 5], ['a', 'b'])], ['k', 'v'])
>>> df.select(map_from_arrays(df.k, df.v).alias("map")).show()
+----------------+
| map|
+----------------+
|[2 -> a, 5 -> b]|
+----------------+
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.map_from_arrays(_to_java_column(col1), _to_java_column(col2)))


@since(1.4)
def array(*cols):
"""Creates a new array column.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ object FunctionRegistry {
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[ElementAt]("element_at"),
expression[MapFromArrays]("map_from_arrays"),
expression[MapKeys]("map_keys"),
expression[MapValues]("map_values"),
expression[MapEntries]("map_entries"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
Expand Down Expand Up @@ -236,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def prettyName: String = "map"
}

/**
* Returns a catalyst Map containing the two arrays in children expressions as keys and values.
*/
@ExpressionDescription(
usage = """
_FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements
in keys should not be null""",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and duplicated.

examples = """
Examples:
> SELECT _FUNC_([1.0, 3.0], ['2', '4']);
{1.0:"2",3.0:"4"}
""", since = "2.4.0")
case class MapFromArrays(left: Expression, right: Expression)
extends BinaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType)

override def dataType: DataType = {
MapType(
keyType = left.dataType.asInstanceOf[ArrayType].elementType,
valueType = right.dataType.asInstanceOf[ArrayType].elementType,
valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
}

override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
val keyArrayData = keyArray.asInstanceOf[ArrayData]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you detect duplicities first?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please let us know where this specification is described or is derived from? It is not written here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although it's not specified, duplicated key can lead to non-determinism of returned values in future. Currently, GetMapValueUtil.getValueEval returns a value for the first key in the map, but there is TODO to change O(n) algorithm. So I'm wondering how it would behave if some hashing was introduced.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. According to current Spark implementation, for example, CreateMap allows us to have duplicated key.
It would be good to discuss such a behavior change in another PR. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we don't have to change it now. But I would like to agree on a consistent approach for the new functions, since this is also related to SPARK-23934 and SPARK-23936.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to err on the safe side here. CreateMap should be fixed IMO.

val valueArrayData = valueArray.asInstanceOf[ArrayData]
if (keyArrayData.numElements != valueArrayData.numElements) {
throw new RuntimeException("The given two arrays should have the same length")
}
val leftArrayType = left.dataType.asInstanceOf[ArrayType]
if (leftArrayType.containsNull) {
var i = 0
while (i < keyArrayData.numElements) {
if (keyArrayData.isNullAt(i)) {
throw new RuntimeException("Cannot use null as map key!")
}
i += 1
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use loop to null-check without converting to object array?

}
new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy())
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => {
val arrayBasedMapData = classOf[ArrayBasedMapData].getName
val leftArrayType = left.dataType.asInstanceOf[ArrayType]
val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else {
val i = ctx.freshName("i")
s"""
|for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
| if ($keyArrayData.isNullAt($i)) {
| throw new RuntimeException("Cannot use null as map key!");
| }
|}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can null-check without converting to object array.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

However, I realized we have to evaluate each element as CreateMap does. I think that we have to update eval and codegen.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, but I couldn't get it. I might miss something, but I thought we can simply do like:

for (int $i = 0; $i < $keyArrayData.numElements(); $i++) {
  if ($keyArrayData.isNullAt($i)) {
    throw new RuntimeException("Cannot use null as map key!");
  }
}

Doesn't this work?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code should work if we evaluate each element to make isNullAt() valid.

I think that my mistake is not to currently evaluate each element in keyArrayData and valueArrayData.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. An array has been evaluated.

""".stripMargin
}
s"""
|if ($keyArrayData.numElements() != $valueArrayData.numElements()) {
| throw new RuntimeException("The given two arrays should have the same length");
|}
|$keyArrayElemNullCheck
|${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy());
""".stripMargin
})
}

override def prettyName: String = "map_from_arrays"
}

/**
* An expression representing a not yet available attribute name. This expression is unevaluable
* and as its name suggests it is a temporary place holder until we're able to determine the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("MapFromArrays") {
def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
// catalyst map is order-sensitive, so we create ListMap here to preserve the elements order.
scala.collection.immutable.ListMap(keys.zip(values): _*)
}

val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25)
val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25)
val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_))

val intArray = Literal.create(intSeq, ArrayType(IntegerType, false))
val longArray = Literal.create(longSeq, ArrayType(LongType, false))
val strArray = Literal.create(strSeq, ArrayType(StringType, false))

val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true))
val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

intWithNullArray?

val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

longWithNullArray?


val nullArray = Literal.create(null, ArrayType(StringType, false))

checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq))
checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq))
checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq))

checkEvaluation(
MapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq))
checkEvaluation(
MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(
MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq))
checkEvaluation(MapFromArrays(nullArray, nullArray), null)

intercept[RuntimeException] {
checkEvaluation(MapFromArrays(intwithNullArray, strArray), null)
}
intercept[RuntimeException] {
checkEvaluation(
MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null)
}
}

test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,17 @@ object functions {
@scala.annotation.varargs
def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }

/**
* Creates a new map column. The array in the first column is used for keys. The array in the
* second column is used for values. All elements in the array for key should not be null.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and duplicated

*
* @group normal_funcs
* @since 2.4
*/
def map_from_arrays(keys: Column, values: Column): Column = withExpr {
MapFromArrays(keys.expr, values.expr)
}

/**
* Marks a DataFrame as small enough for use in broadcast joins.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,36 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(row.getMap[Int, String](0) === Map(2 -> "a"))
}

test("map with arrays") {
val df1 = Seq((Seq(1, 2), Seq("a", "b"))).toDF("k", "v")
val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
val row = df1.select(map_from_arrays($"k", $"v")).first()
assert(row.schema(0).dataType === expectedType)
assert(row.getMap[Int, String](0) === Map(1 -> "a", 2 -> "b"))
checkAnswer(df1.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> "a", 2 -> "b"))))

val df2 = Seq((Seq(1, 2), Seq(null, "b"))).toDF("k", "v")
checkAnswer(df2.select(map_from_arrays($"k", $"v")), Seq(Row(Map(1 -> null, 2 -> "b"))))

val df3 = Seq((null, null)).toDF("k", "v")
checkAnswer(df3.select(map_from_arrays($"k", $"v")), Seq(Row(null)))

val df4 = Seq((1, "a")).toDF("k", "v")
intercept[AnalysisException] {
df4.select(map_from_arrays($"k", $"v"))
}

val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v")
intercept[RuntimeException] {
df5.select(map_from_arrays($"k", $"v")).collect
}

val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v")
intercept[RuntimeException] {
df6.select(map_from_arrays($"k", $"v")).collect
}
}

test("struct with column name") {
val df = Seq((1, "str")).toDF("a", "b")
val row = df.select(struct("a", "b")).first()
Expand Down