From 514d4800b68847b571ddab52e002795d9b673db3 Mon Sep 17 00:00:00 2001
From: darionyaphet
Date: Tue, 13 Jun 2017 16:38:20 +0800
Subject: [PATCH] [SPARK-21073] Support map_keys and map_values functions in
 Dataset

---
 .../org/apache/spark/sql/functions.scala      | 18 ++++++++++++++
 .../spark/sql/CollectionFunctionsSuite.scala  | 26 +++++++++++++++++++
 2 files changed, 44 insertions(+)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CollectionFunctionsSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8d2e1f32da05..43961cfd4a17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1006,6 +1006,24 @@ object functions {
   @scala.annotation.varargs
   def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }
 
+  /**
+   * Returns the keys of the map as an array.
+   *
+   * @group collection_funcs
+   */
+  def map_keys(column: Column): Column = withExpr {
+    MapKeys(column.expr)
+  }
+
+  /**
+   * Returns the values of the map as an array.
+   *
+   * @group collection_funcs
+   */
+  def map_values(column: Column): Column = withExpr {
+    MapValues(column.expr)
+  }
+
   /**
    * Marks a DataFrame as small enough for use in broadcast joins.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollectionFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollectionFunctionsSuite.scala
new file mode 100644
index 000000000000..1b831b6028a8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollectionFunctionsSuite.scala
@@ -0,0 +1,26 @@
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+import org.apache.spark.sql.functions._
+
+class CollectionFunctionsSuite extends QueryTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("function map_keys") {
+    val df = Seq(("1", 1), ("2", 2), ("3", 3)).toDF("col0", "col1")
+    val mapDF = df.select(map($"col0", $"col1").as("column0"))
+    checkAnswer(
+      mapDF.select(map_keys($"column0")),
+      Seq(Row(Array("1")), Row(Array("2")), Row(Array("3"))))
+  }
+
+  test("function map_values") {
+    val df = Seq(("1", 1), ("2", 2), ("3", 3)).toDF("col0", "col1")
+    val mapDF = df.select(map($"col0", $"col1").as("column0"))
+    checkAnswer(
+      mapDF.select(map_values($"column0")),
+      Seq(Row(Array(1)), Row(Array(2)), Row(Array(3))))
+  }
+}
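
A minimal usage sketch of the two new functions, for reviewers who want to try them interactively. It assumes this patch has been applied and a SparkSession named spark is in scope (as in spark-shell); the column names are illustrative only and the commented output is approximate.

  import org.apache.spark.sql.functions._
  import spark.implicits._

  // Build a one-entry MapType column from two flat columns, then inspect it.
  val mapDF = Seq(("a", 1), ("b", 2)).toDF("k", "v")
    .select(map($"k", $"v").as("m"))

  mapDF.select(map_keys($"m")).show()    // rows: [a], [b]
  mapDF.select(map_values($"m")).show()  // rows: [1], [2]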