-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23933][SQL] Add map_from_arrays function #21258
Changes from 12 commits
e4171e1
95d92d8
1df6bb5
2075770
d5ff7be
4eee89d
7b66ab4
2fcbb80
228fcc6
6d53a96
a4b3ec2
a0b4ac5
38d0868
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder | |
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} | ||
import org.apache.spark.sql.catalyst.util._ | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.Platform | ||
import org.apache.spark.unsafe.array.ByteArrayMethods | ||
|
@@ -236,6 +236,76 @@ case class CreateMap(children: Seq[Expression]) extends Expression { | |
override def prettyName: String = "map" | ||
} | ||
|
||
/** | ||
* Returns a catalyst Map containing the two arrays in children expressions as keys and values. | ||
*/ | ||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(keys, values) - Creates a map with a pair of the given key/value arrays. All elements | ||
in keys should not be null""", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and duplicated. |
||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_([1.0, 3.0], ['2', '4']); | ||
{1.0:"2",3.0:"4"} | ||
""", since = "2.4.0") | ||
case class MapFromArrays(left: Expression, right: Expression) | ||
extends BinaryExpression with ExpectsInputTypes { | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType) | ||
|
||
override def dataType: DataType = { | ||
MapType( | ||
keyType = left.dataType.asInstanceOf[ArrayType].elementType, | ||
valueType = right.dataType.asInstanceOf[ArrayType].elementType, | ||
valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) | ||
} | ||
|
||
override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { | ||
val keyArrayData = keyArray.asInstanceOf[ArrayData] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't you detect duplicities first? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please let us know where this specification is described or is derived from? It is not written here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although it's not specified, duplicated key can lead to non-determinism of returned values in future. Currently, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. According to current Spark implementation, for example, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, we don't have to change it now. But I would like to agree on a consistent approach for the new functions, since this is also related to SPARK-23934 and SPARK-23936. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to err on the safe side here. |
||
val valueArrayData = valueArray.asInstanceOf[ArrayData] | ||
if (keyArrayData.numElements != valueArrayData.numElements) { | ||
throw new RuntimeException("The given two arrays should have the same length") | ||
} | ||
val leftArrayType = left.dataType.asInstanceOf[ArrayType] | ||
if (leftArrayType.containsNull) { | ||
var i = 0 | ||
while (i < keyArrayData.numElements) { | ||
if (keyArrayData.isNullAt(i)) { | ||
throw new RuntimeException("Cannot use null as map key!") | ||
} | ||
i += 1 | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can use loop to null-check without converting to object array? |
||
} | ||
new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) | ||
} | ||
|
||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { | ||
val arrayBasedMapData = classOf[ArrayBasedMapData].getName | ||
val leftArrayType = left.dataType.asInstanceOf[ArrayType] | ||
val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { | ||
val i = ctx.freshName("i") | ||
s""" | ||
|for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { | ||
| if ($keyArrayData.isNullAt($i)) { | ||
| throw new RuntimeException("Cannot use null as map key!"); | ||
| } | ||
|} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can null-check without converting to object array. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch, thanks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. However, I realized we have to evaluate each element as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm sorry, but I couldn't get it. I might miss something, but I thought we can simply do like:
Doesn't this work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code should work if we evaluate each element to make I think that my mistake is not to currently evaluate each element in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. An array has been evaluated. |
||
""".stripMargin | ||
} | ||
s""" | ||
|if ($keyArrayData.numElements() != $valueArrayData.numElements()) { | ||
| throw new RuntimeException("The given two arrays should have the same length"); | ||
|} | ||
|$keyArrayElemNullCheck | ||
|${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); | ||
""".stripMargin | ||
}) | ||
} | ||
|
||
override def prettyName: String = "map_from_arrays" | ||
} | ||
|
||
/** | ||
* An expression representing a not yet available attribute name. This expression is unevaluable | ||
* and as its name suggests it is a temporary place holder until we're able to determine the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -186,6 +186,50 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { | |
} | ||
} | ||
|
||
test("MapFromArrays") { | ||
def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { | ||
// catalyst map is order-sensitive, so we create ListMap here to preserve the element order. | ||
scala.collection.immutable.ListMap(keys.zip(values): _*) | ||
} | ||
|
||
val intSeq = Seq(5, 10, 15, 20, 25) | ||
val longSeq = intSeq.map(_.toLong) | ||
val strSeq = intSeq.map(_.toString) | ||
val integerSeq = Seq[java.lang.Integer](5, 10, 15, 20, 25) | ||
val intWithNullSeq = Seq[java.lang.Integer](5, 10, null, 20, 25) | ||
val longWithNullSeq = intSeq.map(java.lang.Long.valueOf(_)) | ||
|
||
val intArray = Literal.create(intSeq, ArrayType(IntegerType, false)) | ||
val longArray = Literal.create(longSeq, ArrayType(LongType, false)) | ||
val strArray = Literal.create(strSeq, ArrayType(StringType, false)) | ||
|
||
val integerArray = Literal.create(integerSeq, ArrayType(IntegerType, true)) | ||
val intwithNullArray = Literal.create(intWithNullSeq, ArrayType(IntegerType, true)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
val longwithNullArray = Literal.create(longWithNullSeq, ArrayType(LongType, true)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
val nullArray = Literal.create(null, ArrayType(StringType, false)) | ||
|
||
checkEvaluation(MapFromArrays(intArray, longArray), createMap(intSeq, longSeq)) | ||
checkEvaluation(MapFromArrays(intArray, strArray), createMap(intSeq, strSeq)) | ||
checkEvaluation(MapFromArrays(integerArray, strArray), createMap(integerSeq, strSeq)) | ||
|
||
checkEvaluation( | ||
MapFromArrays(strArray, intwithNullArray), createMap(strSeq, intWithNullSeq)) | ||
checkEvaluation( | ||
MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) | ||
checkEvaluation( | ||
MapFromArrays(strArray, longwithNullArray), createMap(strSeq, longWithNullSeq)) | ||
checkEvaluation(MapFromArrays(nullArray, nullArray), null) | ||
|
||
intercept[RuntimeException] { | ||
checkEvaluation(MapFromArrays(intwithNullArray, strArray), null) | ||
} | ||
intercept[RuntimeException] { | ||
checkEvaluation( | ||
MapFromArrays(intArray, Literal.create(Seq(1), ArrayType(IntegerType))), null) | ||
} | ||
} | ||
|
||
test("CreateStruct") { | ||
val row = create_row(1, 2, 3) | ||
val c1 = 'a.int.at(0) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1070,6 +1070,17 @@ object functions { | |
@scala.annotation.varargs | ||
def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } | ||
|
||
/** | ||
* Creates a new map column. The array in the first column is used for keys. The array in the | ||
* second column is used for values. All elements in the keys array must not be null. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and duplicated |
||
* | ||
* @group normal_funcs | ||
* @since 2.4.0 | ||
*/ | ||
def map_from_arrays(keys: Column, values: Column): Column = withExpr { | ||
MapFromArrays(keys.expr, values.expr) | ||
} | ||
|
||
/** | ||
* Marks a DataFrame as small enough for use in broadcast joins. | ||
* | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and duplicated?