Skip to content

Commit

Permalink
review comments
Browse files: browse the repository at this point in the history
  • Loading branch information
codeatri committed Aug 15, 2018
1 parent b73106d commit daf7935
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,9 @@ case class ArrayAggregate(
usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> v + 1);
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
map(array(1, 2, 3), array(2, 3, 4))
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v);
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
map(array(1, 2, 3), array(2, 4, 6))
""",
since = "2.4.0")
Expand All @@ -517,33 +517,26 @@ case class TransformValues(

override def nullable: Boolean = argument.nullable

override def dataType: DataType = {
val map = argument.dataType.asInstanceOf[MapType]
MapType(map.keyType, function.dataType, function.nullable)
}
@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

@transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType
override def dataType: DataType = MapType(keyType, function.dataType, valueContainsNull)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction)
: TransformValues = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

@transient lazy val (keyVar, valueVar) = {
val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
(keyVar, valueVar)
}
@transient lazy val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val map = argumentValue.asInstanceOf[MapData]
val f = functionForEval
val resultValues = new GenericArrayData(new Array[Any](map.numElements))
var i = 0
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
resultValues.update(i, f.eval(inputRow))
resultValues.update(i, functionForEval.eval(inputRow))
i += 1
}
new ArrayBasedMapData(map.keyArray(), resultValues)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
}

def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val valueType = expr.dataType.asInstanceOf[MapType].valueType
val keyType = expr.dataType.asInstanceOf[MapType].keyType
TransformValues(expr, createLambda(keyType, false, valueType, true, f))
val map = expr.dataType.asInstanceOf[MapType]
TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
}

test("ArrayTransform") {
Expand Down Expand Up @@ -292,13 +291,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
test("TransformValues") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3),
MapType(IntegerType, IntegerType))
MapType(IntegerType, IntegerType, valueContainsNull = false))
val ai1 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType))
val ain = Literal.create(
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai2 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType))
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1
val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k
Expand All @@ -311,13 +311,18 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(
transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
checkEvaluation(transformValues(ain, plusOne), Map.empty[Int, Int])
checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int])
checkEvaluation(transformValues(ai3, plusOne), null)

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = false))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), MapType(StringType, StringType))
val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType))
Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = true))
val as2 = Literal.create(Map.empty[StringType, StringType],
MapType(StringType, StringType, valueContainsNull = true))
val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true))

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val valueTypeUpdate: (Expression, Expression) => Expression =
Expand All @@ -337,13 +342,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(
transformValues(transformValues(as1, concatValue), concatValue),
Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx"))
checkEvaluation(transformValues(asn, concatValue), Map.empty[String, String])
checkEvaluation(transformValues(asn, valueTypeUpdate), Map.empty[String, Int])
checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String])
checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int])
checkEvaluation(
transformValues(transformValues(asn, concatValue), valueTypeUpdate),
transformValues(transformValues(as2, concatValue), valueTypeUpdate),
Map.empty[String, Int])
}

checkEvaluation(transformValues(as3, concatValue), null)

val ax0 = Literal.create(
Map(1 -> "x", 2 -> "y", 3 -> "z"),
MapType(IntegerType, StringType, valueContainsNull = false))

checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ select exists(ys, y -> y > 30) as v from nested;
select exists(cast(null as array<int>), y -> y > 30) as v;

create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys);

-- Identity Transform Keys in a map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ NULL

-- !query 17
create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys)
-- !query 17 schema
struct<>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2302,7 +2302,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
}

test("transform values function - test various primitive data types combinations") {
test("transform values function - test primitive data types") {
val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
).toDF("i")
Expand All @@ -2316,29 +2316,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
).toDF("y")

val dfExample4 = Seq(
Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0)
Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
).toDF("z")

val dfExample5 = Seq(
Map[Int, Boolean](25 -> true, 26 -> false)
).toDF("a")

val dfExample6 = Seq(
Map[Int, String](25 -> "ab", 26 -> "cd")
).toDF("b")

val dfExample7 = Seq(
Map[Int, Array[Int]](1 -> Array(1, 2))
).toDF("c")

val dfExample8 = Seq(
Map[Int, Double](25 -> 26.1E0, 26 -> 31.2E0, 27 -> 37.1E0)
).toDF("d")

val dfExample10 = Seq(
Map[String, String]("s0" -> "abc", "s1" -> "def")
).toDF("f")

def testMapOfPrimitiveTypesCombination(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"),
Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14))))
Expand Down Expand Up @@ -2373,32 +2357,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3))))

checkAnswer(
dfExample5.selectExpr("transform_values(a, (k, v) -> if(v, k + 1, k + 2))"),
Seq(Row(Map(25 -> 26, 26 -> 28))))

checkAnswer(
dfExample5.selectExpr("transform_values(a, (k, v) -> v AND k = 25)"),
Seq(Row(Map(25 -> true, 26 -> false))))

checkAnswer(
dfExample5.selectExpr("transform_values(a, (k, v) -> v OR k = 26)"),
Seq(Row(Map(25 -> true, 26 -> true))))

checkAnswer(
dfExample6.selectExpr("transform_values(b, (k, v) -> k + length(v))"),
Seq(Row(Map(25 -> 27, 26 -> 28))))

checkAnswer(
dfExample7.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"),
dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"),
Seq(Row(Map(1 -> 3))))

checkAnswer(
dfExample8.selectExpr("transform_values(d, (k, v) -> CAST(v - k AS BIGINT))"),
Seq(Row(Map(25 -> 1, 26 -> 5, 27 -> 10))))

checkAnswer(
dfExample10.selectExpr("transform_values(f, (k, v) -> k || ':' || v)"),
Seq(Row(Map("s0" -> "s0:abc", "s1" -> "s1:def"))))
}

// Test with local relation, the Project will be evaluated without codegen
Expand All @@ -2408,10 +2368,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
dfExample3.cache()
dfExample4.cache()
dfExample5.cache()
dfExample6.cache()
dfExample7.cache()
dfExample8.cache()
dfExample10.cache()
// Test with cached relation, the Project will be evaluated with codegen
testMapOfPrimitiveTypesCombination()
}
Expand Down Expand Up @@ -2487,6 +2443,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Map[String, String]("a" -> "b")
).toDF("j")

val dfExample3 = Seq(
Seq(1, 2, 3, 4)
).toDF("x")

def testInvalidLambdaFunctions(): Unit = {

val ex1 = intercept[AnalysisException] {
Expand All @@ -2498,11 +2458,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)")
}
assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex3 = intercept[AnalysisException] {
dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)")
}
assert(ex3.getMessage.contains(
"data type mismatch: argument 1 requires map type"))
}

testInvalidLambdaFunctions()
dfExample1.cache()
dfExample2.cache()
dfExample3.cache()
testInvalidLambdaFunctions()
}

Expand Down

0 comments on commit daf7935

Please sign in to comment.