[SPARK-23940][SQL] Add transform_values SQL function #22045

Closed
wants to merge 9 commits into from
@@ -445,6 +445,7 @@ object FunctionRegistry {
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
expression[TransformValues]("transform_values"),
CreateStruct.registryEntry,

// misc functions
@@ -442,3 +442,61 @@ case class ArrayAggregate(

override def prettyName: String = "aggregate"
}

/**
* Transform Values for every entry of the map by applying transform_values function.
* Returns map wth transformed values
Review comment (Contributor):

typos: "Transforms values"; "with".
Could you also think of a better comment?

*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
Review comment (Member):

nit: indent

examples = """
Review comment (Member):

ditto.

Examples:
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + 1);
Review comment (Contributor):

nit: (k, v), and I would use v + 1 instead of k + 1.

map(array(1, 2, 3), array(2, 3, 4))
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
map(array(1, 2, 3), array(2, 4, 6))
""",
since = "2.4.0")
Review comment (Member):

ditto.

case class TransformValues(
input: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

override def dataType: DataType = {
val map = input.dataType.asInstanceOf[MapType]
MapType(map.keyType, function.dataType, map.valueContainsNull)
Review comment (Contributor):

map.valueContainsNull -> function.nullable?
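
A minimal sketch of the reasoning behind this suggestion, using simplified stand-ins rather than Spark's classes (`SimpleMapType`, `SimpleLambda`, and `resultType` are hypothetical names). The output values are produced by the lambda, so whether they can be null plausibly depends on the lambda's nullability rather than on the input map's `valueContainsNull`.

```scala
// Simplified stand-ins for Spark's MapType and LambdaFunction; hypothetical names.
final case class SimpleMapType(keyType: String, valueType: String, valueContainsNull: Boolean)
final case class SimpleLambda(resultType: String, nullable: Boolean)

object ValueNullability {
  // The key type is carried over from the input map.
  // The value type and the value nullability both come from the lambda,
  // because every output value is a lambda result.
  def resultType(in: SimpleMapType, f: SimpleLambda): SimpleMapType =
    SimpleMapType(in.keyType, f.resultType, f.nullable)
}
```

For instance, a lambda such as (k, v) -> k * k over non-null keys would not produce nulls even when the input values contain them, which is consistent with the valueUpdate test case on ai1 in this PR.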

}

override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
Review comment (Contributor):

This is already specified by MapBasedSimpleHigherOrderFunction.


@transient val (keyType, valueType, valueContainsNull) =
HigherOrderFunction.mapKeyValueArgumentType(input.dataType)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction):
Review comment (Contributor):

nit: formatting

TransformValues = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

@transient lazy val (keyVar, valueVar) = {
val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
(keyVar, valueVar)
}
Review comment (Member):

nit: how about:

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
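
The suggestion relies on Scala's pattern definitions, which can destructure a value directly in a `val`. A self-contained sketch with hypothetical stand-in classes (`Expr`, `Var`, `Lambda`, and `Holder` are not Spark types), assuming Scala 2 as used by Spark at the time:

```scala
// Hypothetical stand-ins for Expression, NamedLambdaVariable and LambdaFunction.
sealed trait Expr
final case class Var(name: String) extends Expr
final case class Lambda(body: Expr, arguments: List[Expr], hidden: Boolean) extends Expr

final class Holder(function: Expr) {
  // Pattern definition: keyVar and valueVar become lazy members of Holder.
  // The match runs on first access and throws scala.MatchError if
  // `function` is not a Lambda with exactly two Var arguments.
  lazy val Lambda(_, (keyVar: Var) :: (valueVar: Var) :: Nil, _) = function
}

object PatternValDemo {
  def main(args: Array[String]): Unit = {
    val h = new Holder(Lambda(Var("body"), List(Var("k"), Var("v")), hidden = false))
    println(s"key variable: ${h.keyVar.name}, value variable: ${h.valueVar.name}")
  }
}
```

Compared with building a tuple and destructuring it, this form names each variable once, at the cost of a slightly denser line.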


override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val map = value.asInstanceOf[MapData]
val f = functionForEval
val resultValues = new GenericArrayData(new Array[Any](map.numElements))
var i = 0
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
resultValues.update(i, f.eval(inputRow))
i += 1
}
new ArrayBasedMapData(map.keyArray(), resultValues)
}
override def prettyName: String = "transform_values"
}
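
To make the evaluation loop above easier to follow, here is a plain-Scala sketch of the same idea without Spark's `MapData` or lambda variables (`TransformValuesSketch` and its signature are illustrative assumptions). Keys are passed through unchanged and each value is replaced by the result of applying the function to its key and value.

```scala
object TransformValuesSketch {
  // Keys and values are parallel sequences, mirroring keyArray()/valueArray().
  def transformValues[K, V, R](keys: IndexedSeq[K], values: IndexedSeq[V])(
      f: (K, V) => R): (IndexedSeq[K], IndexedSeq[R]) = {
    require(keys.length == values.length, "keys and values must have the same length")
    val out = Vector.newBuilder[R]
    var i = 0
    while (i < keys.length) {
      // Each output value depends on both the key and the old value.
      out += f(keys(i), values(i))
      i += 1
    }
    // The key sequence is reused as-is; only the values are new.
    (keys, out.result())
  }
}
```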
@@ -80,6 +80,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
aggregate(expr, zero, merge, identity)
}

def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val valueType = expr.dataType.asInstanceOf[MapType].valueType
val keyType = expr.dataType.asInstanceOf[MapType].keyType
TransformValues(expr, createLambda(keyType, false, valueType, true, f))
Review comment (Member):

We should use valueContainsNull instead of true?

}

test("ArrayTransform") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -230,4 +236,59 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
(acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
15)
}

test("TransformValues") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3),
MapType(IntegerType, IntegerType))
Review comment (Member):

Can you add valueContainsNull explicitly?

val ai1 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType))
val ain = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType))
Review comment (Member):

Can you add tests for Literal.create(null, MapType(IntegerType, IntegerType))?
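
A sketch of what such a test would presumably check, assuming that a null map input yields a null result (suggested by `nullSafeEval` and `nullable = input.nullable`, but not confirmed in this diff). `Option` stands in for a nullable map, and `NullInputSketch` is a hypothetical name:

```scala
object NullInputSketch {
  // None models a SQL NULL map; Some(m) models a non-null map.
  // Assumption: a NULL input short-circuits to a NULL result and the
  // function is never called.
  def transformValues(input: Option[Map[Int, Int]])(f: (Int, Int) => Int): Option[Map[Int, Int]] =
    input.map(m => m.map { case (k, v) => k -> f(k, v) })
}
```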


val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1
val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k

checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4))
checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(
transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4))
checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(
transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(transformValues(ain, plusOne), Map.empty[Int, Int])

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), MapType(StringType, StringType))
val asn = Literal.create(Map.empty[String, String], MapType(StringType, StringType))

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val valueTypeUpdate: (Expression, Expression) => Expression =
(k, v) => Length(v) + 1

checkEvaluation(
transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx"))
checkEvaluation(transformValues(as0, valueTypeUpdate),
Map("a" -> 3, "bb" -> 3, "ccc" -> 3))
checkEvaluation(
transformValues(transformValues(as0, concatValue), concatValue),
Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx"))
checkEvaluation(transformValues(as1, concatValue),
Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx"))
checkEvaluation(transformValues(as1, valueTypeUpdate),
Map("a" -> 3, "bb" -> null, "ccc" -> 3))
checkEvaluation(
transformValues(transformValues(as1, concatValue), concatValue),
Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx"))
checkEvaluation(transformValues(asn, concatValue), Map.empty[String, String])
checkEvaluation(transformValues(asn, valueTypeUpdate), Map.empty[String, Int])
checkEvaluation(
transformValues(transformValues(asn, concatValue), valueTypeUpdate),
Map.empty[String, Int])
}
}
@@ -45,3 +45,17 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as

-- Aggregate a null array
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;

create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
Review comment (Member):

nit:

  (1, map(1, 1, 2, 2, 3, 3)),
  (2, map(4, 4, 5, 5, 6, 6))

as t(x, ys);

-- Identity transform of values in a map
select transform_values(ys, (k, v) -> v) as v from nested;

-- Transform values in a map by adding a constant
select transform_values(ys, (k, v) -> v + 1) as v from nested;

-- Transform values in a map using keys and values
select transform_values(ys, (k, v) -> k + v) as v from nested;
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 18


-- !query 0
@@ -145,3 +145,40 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
struct<v:int>
-- !query 14 output
NULL


-- !query 15
create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
as t(x, ys)
-- !query 15 schema
struct<>
-- !query 15 output


-- !query 16
select transform_values(ys, (k, v) -> v) as v from nested
-- !query 16 schema
struct<v:map<int,int>>
-- !query 16 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 17
select transform_values(ys, (k, v) -> v + 1) as v from nested
-- !query 17 schema
struct<v:map<int,int>>
-- !query 17 output
{1:2,2:3,3:4}
{4:5,5:6,6:7}


-- !query 18
select transform_values(ys, (k, v) -> k + v) as v from nested
-- !query 18 schema
struct<v:map<int,int>>
-- !query 18 output
{1:2,2:4,3:6}
{4:8,5:10,6:12}
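
As a cross-check of the expected outputs above, the three queries can be mimicked with ordinary Scala maps. This is only a sketch: `GoldenFileSketch` and its `transformValues` helper are hypothetical and do not go through Spark SQL.

```scala
object GoldenFileSketch {
  // The two rows of the `nested` view, keeping only the map column `ys`.
  val rows: Seq[Map[Int, Int]] = Seq(
    Map(1 -> 1, 2 -> 2, 3 -> 3),
    Map(4 -> 4, 5 -> 5, 6 -> 6))

  // Keeps every key and replaces each value with f(key, value).
  def transformValues(m: Map[Int, Int])(f: (Int, Int) => Int): Map[Int, Int] =
    m.map { case (k, v) => k -> f(k, v) }

  // (k, v) -> v
  val unchanged: Seq[Map[Int, Int]] = rows.map(m => transformValues(m)((_, v) => v))
  // (k, v) -> v + 1
  val plusOne: Seq[Map[Int, Int]] = rows.map(m => transformValues(m)((_, v) => v + 1))
  // (k, v) -> k + v
  val keyPlusValue: Seq[Map[Int, Int]] = rows.map(m => transformValues(m)((k, v) => k + v))
}
```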