-
Notifications
You must be signed in to change notification settings - Fork 28.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-23938][SQL] Add map_zip_with function #22017
Changes from 13 commits
ef56011
34cdf0d
ec583eb
be2e10d
a9020a4
89a3da4
12ad8b2
6aeaaa8
38ce4e7
562ee81
5d2a78e
3c849cb
595161f
6995f1e
2b7e9e5
bcd4e0f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,11 +22,11 @@ import java.util.concurrent.atomic.AtomicReference | |
import scala.collection.mutable | ||
|
||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} | ||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute} | ||
import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} | ||
import org.apache.spark.sql.catalyst.util._ | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.array.ByteArrayMethods | ||
|
||
/** | ||
* A named lambda variable. | ||
|
@@ -442,3 +442,186 @@ case class ArrayAggregate( | |
|
||
override def prettyName: String = "aggregate" | ||
} | ||
|
||
/** | ||
* Merges two given maps into a single map by applying function to the pair of values with | ||
* the same key. | ||
*/ | ||
@ExpressionDescription( | ||
usage = | ||
""" | ||
_FUNC_(map1, map2, function) - Merges two given maps into a single map by applying | ||
function to the pair of values with the same key. For keys only presented in one map, | ||
NULL will be passed as the value for the missing key. If an input map contains duplicated | ||
keys, only the first entry of the duplicated key is passed into the lambda function. | ||
""", | ||
examples = """ | ||
Examples: | ||
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2)); | ||
{1:"ax",2:"by"} | ||
""", | ||
since = "2.4.0") | ||
case class MapZipWith(left: Expression, right: Expression, function: Expression) | ||
extends HigherOrderFunction with CodegenFallback { | ||
|
||
// Lambda expression prepared for interpreted evaluation (head of functionsForEval).
@transient lazy val functionForEval: Expression = functionsForEval.head | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: shall we use |
||
|
||
// (key type, value type, value nullability) extracted from the left input map's MapType.
@transient lazy val (leftKeyType, leftValueType, leftValueContainsNull) = | ||
HigherOrderFunction.mapKeyValueArgumentType(left.dataType) | ||
|
||
// (key type, value type, value nullability) extracted from the right input map's MapType.
@transient lazy val (rightKeyType, rightValueType, rightValueContainsNull) = | ||
HigherOrderFunction.mapKeyValueArgumentType(right.dataType) | ||
|
||
// Widest common key type of the two maps; NullType when no common type can be found.
@transient lazy val keyType = | ||
TypeCoercion.findWiderTypeForTwoExceptDecimals(leftKeyType, rightKeyType).getOrElse(NullType) | ||
|
||
// Interpreted ordering over keyType; used by the brute-force key matching below.
@transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType) | ||
|
||
override def inputs: Seq[Expression] = left :: right :: Nil | ||
|
||
override def functions: Seq[Expression] = function :: Nil | ||
|
||
// Result is null iff either input map can be null (see eval below).
override def nullable: Boolean = left.nullable || right.nullable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
// Result map: common key type -> lambda's result type (nullable if the lambda is).
override def dataType: DataType = MapType(keyType, function.dataType, function.nullable) | ||
|
||
// Both inputs must be maps whose key types are mutually compatible and orderable.
override def checkInputDataTypes(): TypeCheckResult = { | ||
(left.dataType, right.dataType) match { | ||
case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) => | ||
TypeUtils.checkForOrderingExpr(k1, s"function $prettyName") | ||
case _ => TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " + | ||
s"been two ${MapType.simpleString}s with compatible key types, but it's " + | ||
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].") | ||
} | ||
} | ||
|
||
// Binds the lambda: the key argument is never null, while either value argument
// may be null (the key can be missing from one of the maps).
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = { | ||
val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true)) | ||
copy(function = f(function, arguments)) | ||
} | ||
|
||
// Returns null if either input map evaluates to null; otherwise zips the two maps.
override def eval(input: InternalRow): Any = { | ||
val value1 = left.eval(input) | ||
if (value1 == null) { | ||
null | ||
} else { | ||
val value2 = right.eval(input) | ||
if (value2 == null) { | ||
null | ||
} else { | ||
nullSafeEval(input, value1, value2) | ||
} | ||
} | ||
} | ||
|
||
// Destructure the bound lambda into its key / left-value / right-value variables.
@transient lazy val LambdaFunction(_, Seq( | ||
keyVar: NamedLambdaVariable, | ||
value1Var: NamedLambdaVariable, | ||
value2Var: NamedLambdaVariable), | ||
_) = function | ||
|
||
// Hash-based lookup is used only for atomic key types; BinaryType is excluded
// (NOTE(review): presumably because byte-array equality is referential — confirm).
private def keyTypeSupportsEquals = keyType match { | ||
case BinaryType => false | ||
case _: AtomicType => true | ||
case _ => false | ||
} | ||
|
||
/** | ||
* The function accepts two key arrays and returns a collection of keys with indexes | ||
* to value arrays. Indexes are represented as an array of two items. This is a small | ||
* optimization leveraging mutability of arrays. | ||
*/ | ||
@transient private lazy val getKeysWithValueIndexes: | ||
(ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = { | ||
if (keyTypeSupportsEquals) { | ||
getKeysWithIndexesFast | ||
} else { | ||
getKeysWithIndexesBruteForce | ||
} | ||
} | ||
|
||
// Guards against result sizes exceeding the maximum supported array length.
private def assertSizeOfArrayBuffer(size: Int): Unit = { | ||
if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { | ||
throw new RuntimeException(s"Unsuccessful try to zip maps with $size " + | ||
s"unique keys due to exceeding the array size limit " + | ||
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") | ||
} | ||
} | ||
|
||
// Hash-based key collection. LinkedHashMap preserves first-seen key order; only the
// first occurrence of a duplicated key within each map is recorded (isEmpty check).
private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = { | ||
val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]] | ||
for((z, array) <- Array((0, keys1), (1, keys2))) { | ||
var i = 0 | ||
while (i < array.numElements()) { | ||
val key = array.get(i, keyType) | ||
hashMap.get(key) match { | ||
case Some(indexes) => | ||
if (indexes(z).isEmpty) { | ||
indexes(z) = Some(i) | ||
} | ||
case None => | ||
val indexes = Array[Option[Int]](None, None) | ||
indexes(z) = Some(i) | ||
hashMap.put(key, indexes) | ||
} | ||
i += 1 | ||
} | ||
} | ||
hashMap | ||
} | ||
|
||
// Quadratic fallback for key types without usable equals/hashCode: each key is
// compared against the already-collected keys via the interpreted ordering.
private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = { | ||
val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])] | ||
for((z, array) <- Array((0, keys1), (1, keys2))) { | ||
var i = 0 | ||
while (i < array.numElements()) { | ||
val key = array.get(i, keyType) | ||
var found = false | ||
var j = 0 | ||
while (!found && j < arrayBuffer.size) { | ||
val (bufferKey, indexes) = arrayBuffer(j) | ||
if (ordering.equiv(bufferKey, key)) { | ||
found = true | ||
if(indexes(z).isEmpty) { | ||
indexes(z) = Some(i) | ||
} | ||
} | ||
j += 1 | ||
} | ||
if (!found) { | ||
assertSizeOfArrayBuffer(arrayBuffer.size) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we check this only once at the end in order to avoid the overhead at each iteration? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The purpose of this line is to avoid There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, because you are using an ArrayBuffer....makes sense, thanks |
||
val indexes = Array[Option[Int]](None, None) | ||
indexes(z) = Some(i) | ||
arrayBuffer += Tuple2(key, indexes) | ||
} | ||
i += 1 | ||
} | ||
} | ||
arrayBuffer | ||
} | ||
|
||
// Zips the two (non-null) maps: evaluates the lambda once per distinct key with
// the key and the (possibly null) value from each map bound to the lambda variables.
private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = { | ||
val mapData1 = value1.asInstanceOf[MapData] | ||
val mapData2 = value2.asInstanceOf[MapData] | ||
val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray()) | ||
val size = keysWithIndexes.size | ||
val keys = new GenericArrayData(new Array[Any](size)) | ||
val values = new GenericArrayData(new Array[Any](size)) | ||
val valueData1 = mapData1.valueArray() | ||
val valueData2 = mapData2.valueArray() | ||
var i = 0 | ||
for ((key, Array(index1, index2)) <- keysWithIndexes) { | ||
val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null) | ||
val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null) | ||
keyVar.value.set(key) | ||
value1Var.value.set(v1) | ||
value2Var.value.set(v2) | ||
keys.update(i, key) | ||
values.update(i, functionForEval.eval(inputRow)) | ||
i += 1 | ||
} | ||
new ArrayBasedMapData(keys, values) | ||
} | ||
|
||
override def prettyName: String = "map_zip_with" | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why except Decimals?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have maps with decimals of different precision as keys,
Cast
will fail in the analysis phase since it can't cast a key to nullable (potential loss of precision). IMHO, the type mismatch exception from this function will be more accurate. WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, I see, good catch! But it led me to another issue. We can't choose those types possibly to be null as a map key. Instead of adding the method, how about modifying
findTypeForComplex
as something like the following. We might need to have another pr to discuss this.
cc @cloud-fan @gatorsmile
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thoughts, do we really need those? Seems like the current coercions rules don't contain possibly cast to null?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok, I see that this is a matter of findTypeForComplex. I'll submit another pr later. Maybe we can go back to findWiderTypeForTwo in TypeCoercion and findCommonTypeDifferentOnlyInNullFlag for keyType.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I submitted a pr #22086.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for both your PRs! I will submit changes once they get in.