[SPARK-23940][SQL] Add transform_values SQL function
## What changes were proposed in this pull request?
This PR adds a `transform_values` function that applies a given function to each entry of a map and transforms the values, leaving the keys unchanged.
```sql
> SELECT transform_values(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
       map(1->2, 2->3, 3->4)

> SELECT transform_values(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       map(1->2, 2->4, 3->6)
```
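
For illustration only (this usage is not part of the PR): since `transform_values` is registered as a SQL function, it can also be reached from the Scala DataFrame API via `selectExpr`. The session `spark`, the DataFrame `df`, and the column name `m` below are assumptions made for this sketch.

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch: build a one-row DataFrame with a map column and rewrite its values.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(Map(1 -> 1, 2 -> 2, 3 -> 3)).toDF("m")

// Keys pass through unchanged; only the values are recomputed by the lambda.
df.selectExpr("transform_values(m, (k, v) -> v + 1) AS m2")
  .show(false) // expected: a single map value of 1 -> 2, 2 -> 3, 3 -> 4
```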
## How was this patch tested?
New tests were added to the following suites (a rough sketch of the style of test follows below):
- `DataFrameFunctionsSuite`
- `HigherOrderFunctionsSuite`
- `SQLQueryTestSuite`
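
A minimal sketch, under assumptions, of the DataFrame-level style of test added here: the suite name, the `SharedSQLContext` harness, and the column name `m` are illustrative only, and the authoritative tests are the ones in the diff below.

```scala
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

// Sketch of a DataFrameFunctionsSuite-style test: run the new SQL function through
// selectExpr and compare the result against the expected map.
class TransformValuesSketchSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("transform_values adds one to every value") {
    val df = Seq(Map(1 -> 1, 2 -> 2, 3 -> 3)).toDF("m")
    checkAnswer(
      df.selectExpr("transform_values(m, (k, v) -> v + 1)"),
      Seq(Row(Map(1 -> 2, 2 -> 3, 3 -> 4))))
  }
}
```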

Closes #22045 from codeatri/SPARK-23940.

Authored-by: codeatri <nehapatil6@gmail.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
codeatri authored and ueshin committed Aug 17, 2018
1 parent 9251c61 commit f161409
Showing 6 changed files with 332 additions and 3 deletions.
@@ -446,6 +446,7 @@ object FunctionRegistry {
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformValues]("transform_values"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
expression[ZipWith]("zip_with"),
@@ -527,7 +527,7 @@ case class TransformKeys(
}

@transient lazy val LambdaFunction(
    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function


override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
@@ -550,6 +550,54 @@ case class TransformKeys(
override def prettyName: String = "transform_keys"
}

/**
 * Returns a map with the same keys as the input map, where each value is the result of
 * applying the function to the corresponding (key, value) entry.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
examples = """
Examples:
      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
       {1:2,2:3,3:4}
      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       {1:2,2:4,3:6}
""",
since = "2.4.0")
case class TransformValues(
argument: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = argument.nullable

@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction)
: TransformValues = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

@transient lazy val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val map = argumentValue.asInstanceOf[MapData]
val resultValues = new GenericArrayData(new Array[Any](map.numElements))
var i = 0
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
resultValues.update(i, functionForEval.eval(inputRow))
i += 1
}
new ArrayBasedMapData(map.keyArray(), resultValues)
}

override def prettyName: String = "transform_values"
}

/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
@@ -101,6 +101,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
aggregate(expr, zero, merge, identity)
}

def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val map = expr.dataType.asInstanceOf[MapType]
TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
}

test("ArrayTransform") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -358,6 +363,74 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
}

test("TransformValues") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val ai1 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai2 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1
val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k

checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4))
checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(
transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4))
checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(
transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int])
checkEvaluation(transformValues(ai3, plusOne), null)

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = false))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = true))
val as2 = Literal.create(Map.empty[StringType, StringType],
MapType(StringType, StringType, valueContainsNull = true))
val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true))

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val valueTypeUpdate: (Expression, Expression) => Expression =
(k, v) => Length(v) + 1

checkEvaluation(
transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx"))
checkEvaluation(transformValues(as0, valueTypeUpdate),
Map("a" -> 3, "bb" -> 3, "ccc" -> 3))
checkEvaluation(
transformValues(transformValues(as0, concatValue), concatValue),
Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx"))
checkEvaluation(transformValues(as1, concatValue),
Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx"))
checkEvaluation(transformValues(as1, valueTypeUpdate),
Map("a" -> 3, "bb" -> null, "ccc" -> 3))
checkEvaluation(
transformValues(transformValues(as1, concatValue), concatValue),
Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx"))
checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String])
checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int])
checkEvaluation(
transformValues(transformValues(as2, concatValue), valueTypeUpdate),
Map.empty[String, Int])
checkEvaluation(transformValues(as3, concatValue), null)

val ax0 = Literal.create(
Map(1 -> "x", 2 -> "y", 3 -> "z"),
MapType(IntegerType, StringType, valueContainsNull = false))

checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
@@ -74,3 +74,12 @@ select transform_keys(ys, (k, v) -> k + 1) as v from nested;

-- Transform Keys in a map using values
select transform_keys(ys, (k, v) -> k + v) as v from nested;

-- Identity transform of the values in a map
select transform_values(ys, (k, v) -> v) as v from nested;

-- Transform values in a map by adding a constant
select transform_values(ys, (k, v) -> v + 1) as v from nested;

-- Transform values in a map using values
select transform_values(ys, (k, v) -> k + v) as v from nested;
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 20
-- Number of queries: 27


-- !query 0
@@ -226,3 +226,30 @@ struct<v:map<int,int>>
-- !query 23 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}


-- !query 24
select transform_values(ys, (k, v) -> v) as v from nested
-- !query 24 schema
struct<v:map<int,int>>
-- !query 24 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 25
select transform_values(ys, (k, v) -> v + 1) as v from nested
-- !query 25 schema
struct<v:map<int,int>>
-- !query 25 output
{1:2,2:3,3:4}
{4:5,5:6,6:7}


-- !query 26
select transform_values(ys, (k, v) -> k + v) as v from nested
-- !query 26 schema
struct<v:map<int,int>>
-- !query 26 output
{1:2,2:4,3:6}
{4:8,5:10,6:12}