Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-23938][SQL] Add map_zip_with function #22017

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ object FunctionRegistry {
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
expression[MapZipWith]("map_zip_with"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ object TypeCoercion {
})
}

/**
* Similar to [[findTightestCommonType]] but with string promotion.
*/
def findWiderTypeForTwoExceptDecimals(t1: DataType, t2: DataType): Option[DataType] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why except Decimals?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have maps with decimals of different precision as keys. Cast will fail in analysis phase since it can't cast a key to nullable (potential lost of precision). IMHO, the type mismatch exception from this function will be more accurate. WDYT?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see, good catch! But it led me to another issue. We can't choose those types possibly to be null as a map key. Instead of adding the method, how about modifying findTypeForComplex as something like:

private def findTypeForComplex(
      t1: DataType,
      t2: DataType,
      findTypeFunc: (DataType, DataType) => Option[DataType]): Option[DataType] = (t1, t2) match {
  ...
    case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
      findTypeFunc(kt1, kt2)
        .filter(kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt))
        .flatMap { kt =>
          findTypeFunc(vt1, vt2).map { vt =>
            MapType(kt, vt, valueContainsNull1 || valueContainsNull2)
          }
      }
  ...
}

We might need to have another pr to discuss this.

cc @cloud-fan @gatorsmile

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thoughts, do we really need those? Seems like the current coercions rules don't contain possibly cast to null?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I see that this is a matter of findTypeForComplex. I'll submit another pr later. Maybe we can go back to findWiderTypeForTwo in TypeCoercion and findCommonTypeDifferentOnlyInNullFlag for keyType.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I submitted a pr #22086.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for both your PRs! I will submit changes once they get in.

findTightestCommonType(t1, t2)
.orElse(stringPromotion(t1, t2))
.orElse(findTypeForComplex(t1, t2, findWiderTypeForTwoExceptDecimals))
}

/**
* Similar to [[findWiderTypeForTwo]] that can handle decimal types, but can't promote to
* string. If the wider decimal type exceeds system limitation, this rule will truncate
Expand Down Expand Up @@ -602,6 +611,20 @@ object TypeCoercion {

CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })

case m @ MapZipWith(left, right, function) if MapType.acceptsType(left.dataType) &&
MapType.acceptsType(right.dataType) && !m.leftKeyType.sameType(m.rightKeyType) =>
findWiderTypeForTwoExceptDecimals(m.leftKeyType, m.rightKeyType) match {
case Some(finalKeyType) =>
val newLeft = castIfNotSameType(
left,
MapType(finalKeyType, m.leftValueType, m.leftValueContainsNull))
val newRight = castIfNotSameType(
right,
MapType(finalKeyType, m.rightValueType, m.rightValueContainsNull))
MapZipWith(newLeft, newRight, function)
case None => m
}

// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods

/**
* A named lambda variable.
Expand Down Expand Up @@ -442,3 +442,186 @@ case class ArrayAggregate(

override def prettyName: String = "aggregate"
}

/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
*/
@ExpressionDescription(
usage =
"""
_FUNC_(map1, map2, function) - Merges two given maps into a single map by applying
function to the pair of values with the same key. For keys only presented in one map,
NULL will be passed as the value for the missing key. If an input map contains duplicated
keys, only the first entry of the duplicated key is passed into the lambda function.
""",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2));
{1:"ax",2:"by"}
""",
since = "2.4.0")
case class MapZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {

@transient lazy val functionForEval: Expression = functionsForEval.head
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: shall we use def here to follow the comment #21954 (comment)?


@transient lazy val (leftKeyType, leftValueType, leftValueContainsNull) =
HigherOrderFunction.mapKeyValueArgumentType(left.dataType)

@transient lazy val (rightKeyType, rightValueType, rightValueContainsNull) =
HigherOrderFunction.mapKeyValueArgumentType(right.dataType)

@transient lazy val keyType =
TypeCoercion.findWiderTypeForTwoExceptDecimals(leftKeyType, rightKeyType).getOrElse(NullType)

@transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)

override def inputs: Seq[Expression] = left :: right :: Nil

override def functions: Seq[Expression] = function :: Nil

override def nullable: Boolean = left.nullable || right.nullable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left.nullable && right.nullable? Because if one side is empty map, NULL will be passed as the value for each key in other side.

Copy link
Contributor Author

@mn-mikke mn-mikke Aug 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nullable flag is rather related to the cases when the whole map is null. The case that you are referring to is handled by valueContainsNull flag of MapType (see the line 496).


override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) =>
TypeUtils.checkForOrderingExpr(k1, s"function $prettyName")
case _ => TypeCheckResult.TypeCheckFailure(s"The input to function $prettyName should have " +
s"been two ${MapType.simpleString}s with compatible key types, but it's " +
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
}
}

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapZipWith = {
val arguments = Seq((keyType, false), (leftValueType, true), (rightValueType, true))
copy(function = f(function, arguments))
}

override def eval(input: InternalRow): Any = {
val value1 = left.eval(input)
if (value1 == null) {
null
} else {
val value2 = right.eval(input)
if (value2 == null) {
null
} else {
nullSafeEval(input, value1, value2)
}
}
}

@transient lazy val LambdaFunction(_, Seq(
keyVar: NamedLambdaVariable,
value1Var: NamedLambdaVariable,
value2Var: NamedLambdaVariable),
_) = function

private def keyTypeSupportsEquals = keyType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
}

/**
* The function accepts two key arrays and returns a collection of keys with indexes
* to value arrays. Indexes are represented as an array of two items. This is a small
* optimization leveraging mutability of arrays.
*/
@transient private lazy val getKeysWithValueIndexes:
(ArrayData, ArrayData) => mutable.Iterable[(Any, Array[Option[Int]])] = {
if (keyTypeSupportsEquals) {
getKeysWithIndexesFast
} else {
getKeysWithIndexesBruteForce
}
}

private def assertSizeOfArrayBuffer(size: Int): Unit = {
if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to zip maps with $size " +
s"unique keys due to exceeding the array size limit " +
s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
}
}

private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = {
val hashMap = new mutable.LinkedHashMap[Any, Array[Option[Int]]]
for((z, array) <- Array((0, keys1), (1, keys2))) {
var i = 0
while (i < array.numElements()) {
val key = array.get(i, keyType)
hashMap.get(key) match {
case Some(indexes) =>
if (indexes(z).isEmpty) {
indexes(z) = Some(i)
}
case None =>
val indexes = Array[Option[Int]](None, None)
indexes(z) = Some(i)
hashMap.put(key, indexes)
}
i += 1
}
}
hashMap
}

private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: ArrayData) = {
val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
for((z, array) <- Array((0, keys1), (1, keys2))) {
var i = 0
while (i < array.numElements()) {
val key = array.get(i, keyType)
var found = false
var j = 0
while (!found && j < arrayBuffer.size) {
val (bufferKey, indexes) = arrayBuffer(j)
if (ordering.equiv(bufferKey, key)) {
found = true
if(indexes(z).isEmpty) {
indexes(z) = Some(i)
}
}
j += 1
}
if (!found) {
assertSizeOfArrayBuffer(arrayBuffer.size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we check this only once at the end in order to avoid the overhead at each iteration?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of this line is to avoid OutOfMemoryError exception when max array size is exceeded and throw something more accurate. Maybe I'm missing something, but wouldn't we break it we checked this only once at the end? The max size could be exceeded in any iteration.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, because you are using an ArrayBuffer....makes sense, thanks

val indexes = Array[Option[Int]](None, None)
indexes(z) = Some(i)
arrayBuffer += Tuple2(key, indexes)
}
i += 1
}
}
arrayBuffer
}

private def nullSafeEval(inputRow: InternalRow, value1: Any, value2: Any): Any = {
val mapData1 = value1.asInstanceOf[MapData]
val mapData2 = value2.asInstanceOf[MapData]
val keysWithIndexes = getKeysWithValueIndexes(mapData1.keyArray(), mapData2.keyArray())
val size = keysWithIndexes.size
val keys = new GenericArrayData(new Array[Any](size))
val values = new GenericArrayData(new Array[Any](size))
val valueData1 = mapData1.valueArray()
val valueData2 = mapData2.valueArray()
var i = 0
for ((key, Array(index1, index2)) <- keysWithIndexes) {
val v1 = index1.map(valueData1.get(_, leftValueType)).getOrElse(null)
val v2 = index2.map(valueData2.get(_, rightValueType)).getOrElse(null)
keyVar.value.set(key)
value1Var.value.set(v1)
value2Var.value.set(v2)
keys.update(i, key)
values.update(i, functionForEval.eval(inputRow))
i += 1
}
new ArrayBasedMapData(keys, values)
}

override def prettyName: String = "map_zip_with"
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
LambdaFunction(function, Seq(lv1, lv2))
}

private def createLambda(
dt1: DataType,
nullable1: Boolean,
dt2: DataType,
nullable2: Boolean,
dt3: DataType,
nullable3: Boolean,
f: (Expression, Expression, Expression) => Expression): Expression = {
val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
val lv3 = NamedLambdaVariable("arg3", dt3, nullable3)
val function = f(lv1, lv2, lv3)
LambdaFunction(function, Seq(lv1, lv2, lv3))
}

def transform(expr: Expression, f: Expression => Expression): Expression = {
val at = expr.dataType.asInstanceOf[ArrayType]
ArrayTransform(expr, createLambda(at.elementType, at.containsNull, f))
Expand Down Expand Up @@ -230,4 +245,118 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
(acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
15)
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
right: Expression,
f: (Expression, Expression, Expression) => Expression): Expression = {
val MapType(kt, vt1, vcn1) = left.dataType.asInstanceOf[MapType]
val MapType(_, vt2, vcn2) = right.dataType.asInstanceOf[MapType]
MapZipWith(left, right, createLambda(kt, false, vt1, vcn1, vt2, vcn2, f))
}

val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false))
val mii4 = MapFromArrays(
Literal.create(Seq(2, 2), ArrayType(IntegerType, false)),
Literal.create(Seq(20, 200), ArrayType(IntegerType, false)))
val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = {
(k, v1, v2) => k * v1 * v2
}

checkEvaluation(
map_zip_with(mii0, mii1, multiplyKeyWithValues),
Map(1 -> -10, 2 -> -80, 3 -> null, 4 -> null))
checkEvaluation(
map_zip_with(mii0, mii2, multiplyKeyWithValues),
Map(1 -> null, 2 -> -80, 3 -> null))
checkEvaluation(
map_zip_with(mii0, mii3, multiplyKeyWithValues),
Map(1 -> null, 2 -> null, 3 -> null))
checkEvaluation(
map_zip_with(mii0, mii4, multiplyKeyWithValues),
Map(1 -> null, 2 -> 800, 3 -> null))
checkEvaluation(
map_zip_with(mii4, mii0, multiplyKeyWithValues),
Map(2 -> 800, 1 -> null, 3 -> null))
checkEvaluation(
map_zip_with(mii0, miin, multiplyKeyWithValues),
null)

val mss0 = Literal.create(Map("a" -> "x", "b" -> "y", "d" -> "z"),
MapType(StringType, StringType, valueContainsNull = false))
val mss1 = Literal.create(Map("d" -> "b", "b" -> "d"),
MapType(StringType, StringType, valueContainsNull = false))
val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false))
val mss4 = MapFromArrays(
Literal.create(Seq("a", "a"), ArrayType(StringType, false)),
Literal.create(Seq("a", "n"), ArrayType(StringType, false)))
val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false))

val concat: (Expression, Expression, Expression) => Expression = {
(k, v1, v2) => Concat(Seq(k, v1, v2))
}

checkEvaluation(
map_zip_with(mss0, mss1, concat),
Map("a" -> null, "b" -> "byd", "d" -> "dzb"))
checkEvaluation(
map_zip_with(mss1, mss2, concat),
Map("d" -> null, "b" -> "bdt", "c" -> null, "a" -> null))
checkEvaluation(
map_zip_with(mss0, mss3, concat),
Map("a" -> null, "b" -> null, "d" -> null))
checkEvaluation(
map_zip_with(mss0, mss4, concat),
Map("a" -> "axa", "b" -> null, "d" -> null))
checkEvaluation(
map_zip_with(mss4, mss0, concat),
Map("a" -> "aax", "b" -> null, "d" -> null))
checkEvaluation(
map_zip_with(mss0, mssn, concat),
null)

def b(data: Byte*): Array[Byte] = Array[Byte](data: _*)

val mbb0 = Literal.create(Map(b(1, 2) -> b(4), b(2, 1) -> b(5), b(1, 3) -> b(8)),
MapType(BinaryType, BinaryType, valueContainsNull = false))
val mbb1 = Literal.create(Map(b(2, 1) -> b(7), b(1, 2) -> b(3), b(1, 1) -> b(6)),
MapType(BinaryType, BinaryType, valueContainsNull = false))
val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null),
MapType(BinaryType, BinaryType, valueContainsNull = true))
val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false))
val mbb4 = MapFromArrays(
Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)),
Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false)))
val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false))

checkEvaluation(
map_zip_with(mbb0, mbb1, concat),
Map(b(1, 2) -> b(1, 2, 4, 3), b(2, 1) -> b(2, 1, 5, 7), b(1, 3) -> null, b(1, 1) -> null))
checkEvaluation(
map_zip_with(mbb1, mbb2, concat),
Map(b(2, 1) -> null, b(1, 2) -> b(1, 2, 3, 2), b(1, 1) -> null, b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb0, mbb3, concat),
Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb0, mbb4, concat),
Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb4, mbb0, concat),
Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null))
checkEvaluation(
map_zip_with(mbb0, mbbn, concat),
null)
}
}
Loading