From ba8854c3a9af18017a62116d9118aabe1847290e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Feb 2015 17:37:45 -0800 Subject: [PATCH 1/5] [SPARK-5573][SQL] Add explode to dataframes --- .../sql/catalyst/expressions/generators.scala | 16 +++++++++++++ .../org/apache/spark/sql/DataFrame.scala | 23 +++++++++++++++++++ .../org/apache/spark/sql/DataFrameImpl.scala | 16 +++++++++++-- .../apache/spark/sql/IncomputableColumn.scala | 6 ++++- .../org/apache/spark/sql/DataFrameSuite.scala | 16 +++++++++++++ 5 files changed, 74 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 43b6482c0171c..c63ffaf2435b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -73,6 +73,22 @@ abstract class Generator extends Expression { } } +case class UserDefinedGenerator( + schema: Seq[Attribute], + function: Row => TraversableOnce[Row], + children: Seq[Expression]) + extends Generator{ + + override protected def makeOutput(): Seq[Attribute] = schema + + override def eval(input: Row): TraversableOnce[Row] = { + val inputRow = new InterpretedProjection(children) + function(inputRow(input)) + } + + override def toString = s"UserDefinedGenerator(${children.mkString(",")})" +} + /** * Given an input array produces a sequence of rows for each value in the array. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 17900c5ee3892..be6f47770d1f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.annotation.{DeveloperApi, Experimental} @@ -458,6 +459,28 @@ trait DataFrame extends RDDApi[Row] { sample(withReplacement, fraction, Utils.random.nextLong) } + /** + * Returns a new [[DataFrame]] where each row has been expanded to zero or more rows by the + * provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the + * input row are implicitly joined with each row that is output by the function. + * + * The following examples use this function to count the number of books which contain + * a given word. + * + * {{{ + * case class Book(title: String, words: String) + * val df: RDD[Book] + * + * case class Word(word: String) + * val allWords = df.explode('words) { + * case Row(words: String) => words.split(" ").map(Word(_)) + * } + * + * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title")) + * }}} + */ + def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame + ///////////////////////////////////////////////////////////////////////////// /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 41da4424ae459..2827fd940f9d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -21,6 +21,7 @@ import java.io.CharArrayWriter import scala.language.implicitConversions import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.collection.JavaConversions._ import com.fasterxml.jackson.core.JsonFactory @@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} +import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} @@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{NumericType, StructType} - /** * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly. */ @@ -292,6 +292,18 @@ private[sql] class DataFrameImpl protected[sql]( Sample(fraction, withReplacement, seed, logicalPlan) } + override def explode[A <: Product : TypeTag] + (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributes = schema.toAttributes + val rowFunction = + f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) + + val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + Generate(generator, join = true, outer = false, None, logicalPlan) + } + + ///////////////////////////////////////////////////////////////////////////// // RDD API ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index 494e49c1317b6..7e52031b380c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD @@ -114,7 +115,10 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err() - ///////////////////////////////////////////////////////////////////////////// + override def explode[A <: Product : TypeTag] + (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err() + + ///////////////////////////////////////////////////////////////////////////// override def head(n: Int): Array[Row] = err() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 7be9215a443f0..e8abef1224dfa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -98,6 +98,22 @@ class DataFrameSuite extends QueryTest { sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) } + test("explode") { + val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters") + val df2 = + df.explode('letters) { + case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq + } + + checkAnswer( + df2 + .select('_1 as 'letter, 'number) + .groupBy('letter) + .agg('letter, countDistinct('number)), + Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil + ) + } + test("selectExpr") { checkAnswer( testData.selectExpr("abs(key)", "value"), From 950707afbe59c976ed3ad044acfe31b7752eaa98 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Feb 2015 17:40:25 -0800 Subject: [PATCH 2/5] fix comments --- .../apache/spark/sql/catalyst/expressions/generators.scala | 3 +++ sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/DataFrameImpl.scala | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index c63ffaf2435b3..0983d274def3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -73,6 +73,9 @@ abstract class Generator extends Expression { } } +/** + * A generator that produces its output using the provided lambda function. + */ case class UserDefinedGenerator( schema: Seq[Attribute], function: Row => TraversableOnce[Row], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index be6f47770d1f3..feac63f45bc8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -464,8 +464,8 @@ trait DataFrame extends RDDApi[Row] { * provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the * input row are implicitly joined with each row that is output by the function. * - * The following examples use this function to count the number of books which contain - * a given word. + * The following example uses this function to count the number of books which contain + * a given word: * * {{{ * case class Book(title: String, words: String) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 2827fd940f9d0..d3564f0839a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -298,8 +298,8 @@ private[sql] class DataFrameImpl protected[sql]( val attributes = schema.toAttributes val rowFunction = f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row])) - val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + Generate(generator, join = true, outer = false, None, logicalPlan) } From d633d0162ac034e4d1b642980a152354ef452f57 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Feb 2015 17:42:56 -0800 Subject: [PATCH 3/5] add scala specific --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index feac63f45bc8a..f1322e59a5956 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -460,9 +460,9 @@ trait DataFrame extends RDDApi[Row] { } /** - * Returns a new [[DataFrame]] where each row has been expanded to zero or more rows by the - * provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of the - * input row are implicitly joined with each row that is output by the function. + * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more + * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of + * the input row are implicitly joined with each row that is output by the function. * * The following example uses this function to count the number of books which contain * a given word: From dc86a5c7ec71f9a6b005175350d98a90de8d63af Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Feb 2015 18:46:57 -0800 Subject: [PATCH 4/5] simple version --- .../scala/org/apache/spark/sql/DataFrame.scala | 17 ++++++++++++++++- .../org/apache/spark/sql/DataFrameImpl.scala | 15 +++++++++++++++ .../apache/spark/sql/IncomputableColumn.scala | 7 ++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f1322e59a5956..87a482304c508 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -461,7 +461,7 @@ trait DataFrame extends RDDApi[Row] { /** * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more - * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of + * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of * the input row are implicitly joined with each row that is output by the function. * * The following example uses this function to count the number of books which contain @@ -481,6 +481,21 @@ trait DataFrame extends RDDApi[Row] { */ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame + + /** + * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero + * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All + * columns of the input row are implicitly joined with each value that is output by the function. + * + * {{{ + * df.explode("words", "word")(words: String => words.split(" ")) + * }}} + */ + def explode[A, B : TypeTag]( + inputColumn: String, + outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame + ///////////////////////////////////////////////////////////////////////////// /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index d3564f0839a1e..2650ef9e3f500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -303,6 +303,21 @@ private[sql] class DataFrameImpl protected[sql]( Generate(generator, join = true, outer = false, None, logicalPlan) } + override def explode[A, B : TypeTag]( + inputColumn: String, + outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame = { + val dataType = ScalaReflection.schemaFor[B].dataType + val attributes = AttributeReference(outputColumn, dataType)() :: Nil + def rowFunction(row: Row) = { + f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType))) + } + val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + + Generate(generator, join = true, outer = false, None, logicalPlan) + + } + ///////////////////////////////////////////////////////////////////////////// // RDD API diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala index 7e52031b380c6..88086a831431e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala @@ -118,7 +118,12 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten override def explode[A <: Product : TypeTag] (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err() - ///////////////////////////////////////////////////////////////////////////// + override def explode[A, B : TypeTag]( + inputColumn: String, + outputColumn: String)( + f: A => TraversableOnce[B]): DataFrame = err() + + ///////////////////////////////////////////////////////////////////////////// override def head(n: Int): Array[Row] = err() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e8abef1224dfa..33b35f376b270 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -98,6 +98,15 @@ class DataFrameSuite extends QueryTest { sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) } + test("simple explode") { + val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words") + + checkAnswer( + df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), + Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil + ) + } + test("explode") { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters") val df2 = From eefd33a20db233a30e85e21997c571c2f5686c80 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 11 Feb 2015 23:28:56 -0800 Subject: [PATCH 5/5] whitespace --- sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index c1c1c11441991..bb5c6226a2217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -308,7 +308,6 @@ private[sql] class DataFrameImpl protected[sql]( } - ///////////////////////////////////////////////////////////////////////////// // RDD API /////////////////////////////////////////////////////////////////////////////