From 52ca0dcdce3695890a8efe1d225b49d9cf9587d2 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Wed, 13 May 2015 04:55:29 +0000
Subject: [PATCH 1/7] [SPARK-7548][SQL] Add explode function for dataframes.

---
 .../sql/catalyst/analysis/Analyzer.scala        | 112 +++++++++++-------
 .../sql/catalyst/analysis/AnalysisSuite.scala   |  10 +-
 .../scala/org/apache/spark/sql/Column.scala     |  13 +-
 .../org/apache/spark/sql/DataFrame.scala        |   5 +-
 .../org/apache/spark/sql/functions.scala        |   5 +
 .../spark/sql/ColumnExpressionSuite.scala       |  51 ++++++++
 6 files changed, 148 insertions(+), 48 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a4c61149dd975..483ad88f5b95c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,6 @@ class Analyzer(
       ResolveGroupingAnalytics ::
       ResolveSortReferences ::
       ResolveGenerate ::
-      ImplicitGenerate ::
       ResolveFunctions ::
       ExtractWindowExpressions ::
       GlobalAggregates ::
@@ -516,66 +515,89 @@ class Analyzer(
   }
 
   /**
-   * When a SELECT clause has only a single expression and that expression is a
-   * [[catalyst.expressions.Generator Generator]] we convert the
-   * [[catalyst.plans.logical.Project Project]] to a [[catalyst.plans.logical.Generate Generate]].
+   * Rewrites table generating expressions that need one or more of the following in order to be
+   * resolved:
+   *  - concrete attribute references for their output.
+   *  - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a [[Generate]].
+   *
+   * Names for the output [[Attribute]]s are extracted from [[Alias]] or [[MultiAlias]]
+   * expressions that wrap the [[Generator]]. If more than one [[Generator]] is found in a
+   * Project, an [[AnalysisException]] is thrown.
    */
-  object ImplicitGenerate extends Rule[LogicalPlan] {
+  object ResolveGenerate extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case Project(Seq(Alias(g: Generator, name)), child) =>
-        Generate(g, join = false, outer = false,
-          qualifier = None, UnresolvedAttribute(name) :: Nil, child)
-      case Project(Seq(MultiAlias(g: Generator, names)), child) =>
-        Generate(g, join = false, outer = false,
-          qualifier = None, names.map(UnresolvedAttribute(_)), child)
+      case p: Generate if !p.child.resolved || !p.generator.resolved => p
+      case g: Generate if g.resolved == false =>
+        g.copy(
+          generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name)))
+
+      case p @ Project(projectList, child) =>
+        // Holds the resolved generator, if one exists in the project list.
+        var resolvedGenerator: Generate = null
+
+        val newProjectList = projectList.flatMap {
+          case AliasedGenerator(generator, names) if generator.childrenResolved =>
+            if (resolvedGenerator != null) {
+              failAnalysis(
+                s"Only one generator allowed per select but ${resolvedGenerator.nodeName} and " +
+                  s"${generator.nodeName} found.")
+            }
+
+            resolvedGenerator =
+              Generate(
+                generator,
+                join = projectList.size > 1, // Only join if there are other expressions in SELECT.
+                outer = false,
+                qualifier = None,
+                generatorOutput = makeGeneratorOutput(generator, names),
+                child)
+
+            resolvedGenerator.generatorOutput
+          case other => other :: Nil
+        }
+
+        if (resolvedGenerator != null) {
+          Project(newProjectList, resolvedGenerator)
+        } else {
+          p
+        }
    }
-  }
 
-  /**
-   * Resolve the Generate, if the output names specified, we will take them, otherwise
-   * we will try to provide the default names, which follow the same rule with Hive.
-   */
-  object ResolveGenerate extends Rule[LogicalPlan] {
-    // Construct the output attributes for the generator,
-    // The output attribute names can be either specified or
-    // auto generated.
+    /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
+    private object AliasedGenerator {
+      def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
+        case Alias(g: Generator, name) => Some((g, name :: Nil))
+        case MultiAlias(g: Generator, names) => Some(g, names)
+        case _ => None
+      }
+    }
+
+    /**
+     * Construct the output attributes for a [[Generator]], given a list of names. If the list
+     * of names is empty, names are assigned by ordinal (i.e., _c0, _c1, ...) to match Hive's
+     * defaults.
+     */
    private def makeGeneratorOutput(
        generator: Generator,
-        generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+        names: Seq[String]): Seq[Attribute] = {
      val elementTypes = generator.elementTypes
 
-      if (generatorOutput.length == elementTypes.length) {
-        generatorOutput.zip(elementTypes).map {
-          case (a, (t, nullable)) if !a.resolved =>
-            AttributeReference(a.name, t, nullable)()
-          case (a, _) => a
+      if (names.length == elementTypes.length) {
+        names.zip(elementTypes).map {
+          case (name, (t, nullable)) =>
+            AttributeReference(name, t, nullable)()
        }
-      } else if (generatorOutput.length == 0) {
+      } else if (names.isEmpty) {
        elementTypes.zipWithIndex.map {
          // keep the default column names as Hive does _c0, _c1, _cN
          case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
        }
      } else {
-        throw new AnalysisException(
-          s"""
-             |The number of aliases supplied in the AS clause does not match
-             |the number of columns output by the UDTF expected
-             |${elementTypes.size} aliases but got ${generatorOutput.size}
-           """.stripMargin)
+        failAnalysis(
+          "The number of aliases supplied in the AS clause does not match the number of " +
+            s"columns output by the UDTF: expected ${elementTypes.size} aliases but got " +
+            s"${names.mkString(",")}.")
      }
    }
-
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case p: Generate if !p.child.resolved || !p.generator.resolved => p
-      case p: Generate if p.resolved == false =>
-        // if the generator output names are not specified, we will use the default ones.
-        Generate(
-          p.generator,
-          join = p.join,
-          outer = p.outer,
-          p.qualifier,
-          makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
-    }
  }
 
  /**
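Note: a user-level sketch of what the merged ResolveGenerate rule above does, assuming a
SQLContext whose implicits are imported (the DataFrame and column names are hypothetical):

    import org.apache.spark.sql.functions._
    import sqlContext.implicits._

    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")

    // Two expressions in the SELECT list: the rule moves the generator into a
    // Generate(explode(intList), join = true, ...) node so that each generated
    // row is joined back to the other selected columns.
    df.select($"a", explode($"intList")).explain(true)

    // A lone generator: join = false and the generator output simply replaces
    // the projection.
    df.select(explode($"intList")).explain(true)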
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6f2f35564d12e..e1d6ac462fbcc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
      StructField("cField", StringType) :: Nil
    ))())
 
+  val listRelation = LocalRelation(
+    AttributeReference("list", ArrayType(IntegerType))())
+
  before {
    caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -159,10 +162,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
        }
      }
 
-      errorMessages.foreach(m => assert(error.getMessage contains m))
+      errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
    }
  }
 
+  errorTest(
+    "too many generators",
+    listRelation.select(Explode('list).as('a), Explode('list).as('b)),
+    "only one generator" :: "explode" :: Nil)
+
  errorTest(
    "unresolved attributes",
    testRelation.select('abcd),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 4d50821620f5e..d04b23428a01e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.sql.functions.lit
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
 import org.apache.spark.sql.types._
 
 
@@ -615,6 +615,17 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   */
  def as(alias: String): Column = Alias(expr, alias)()
 
+  /**
+   * Assigns the given aliases to the results of a table generating function.
+   * {{{
+   *   // Assigns the names `key` and `value` to explode's two output columns.
+   *   df.select(explode($"myMap").as("key" :: "value" :: Nil))
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases)
+
  /**
   * Gives the column an alias.
   * {{{
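Note: a short usage sketch for the new multi-alias overload, assuming a DataFrame `df` with a
MapType column `myMap` (hypothetical names, matching the Scaladoc example above):

    // Each name is bound, in order, to one output column of the generator;
    // exploding a map yields two columns, so two names are required.
    df.select(explode($"myMap").as("key" :: "value" :: Nil))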
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 1f85dac682cbe..46ce387b20fab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,7 +33,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -565,6 +565,9 @@ class DataFrame private[sql](
  def select(cols: Column*): DataFrame = {
    val namedExpressions = cols.map {
      case Column(expr: NamedExpression) => expr
+      // Leave an unaliased explode with an empty list of names since the analyzer will generate
+      // the correct defaults after the nested expression's type has been resolved.
+      case Column(explode: Explode) => MultiAlias(explode, Nil)
      case Column(expr: Expression) => Alias(expr, expr.prettyString)()
    }
    // When user continuously call `select`, speed up analysis by collapsing `Project`
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index fae4bd0fd2994..14453110c6626 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -329,6 +329,11 @@ object functions {
  @scala.annotation.varargs
  def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
 
+  /**
+   * Creates a new row for each element in the given array or map column.
+   */
+  def explode(e: Column): Column = Explode(e.expr)
+
  /**
   * Converts a string expression to lower case.
   *
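Note: the MultiAlias(explode, Nil) wrapping above defers naming to the analyzer, which falls
back to Hive-style ordinal names once the generator's element types are known. A sketch of the
observable behaviour under this patch's defaults, reusing the hypothetical `df` from the
earlier note:

    // No alias: the analyzer assigns the default name _c0 after resolution.
    df.select(explode($"intList"))

    // Explicit alias: a single generated column named "i".
    df.select(explode($"intList").as("i"))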
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d96186c268720..f5ec388567444 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -27,6 +27,57 @@ import org.apache.spark.sql.types._
 class ColumnExpressionSuite extends QueryTest {
  import org.apache.spark.sql.TestData._
 
+  test("single explode") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+    checkAnswer(
+      df.select(explode('intList)),
+      Row(1) :: Row(2) :: Row(3) :: Nil)
+  }
+
+  test("explode and other columns") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+    checkAnswer(
+      df.select($"a", explode('intList)),
+      Row(1, 1) ::
+      Row(1, 2) ::
+      Row(1, 3) :: Nil)
+
+    checkAnswer(
+      df.select($"*", explode('intList)),
+      Row(1, Seq(1,2,3), 1) ::
+      Row(1, Seq(1,2,3), 2) ::
+      Row(1, Seq(1,2,3), 3) :: Nil)
+  }
+
+  test("aliased explode") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+    checkAnswer(
+      df.select(explode('intList).as('int)).select('int),
+      Row(1) :: Row(2) :: Row(3) :: Nil)
+
+    checkAnswer(
+      df.select(explode('intList).as('int)).select(sum('int)),
+      Row(6) :: Nil)
+  }
+
+  test("explode on map") {
+    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode('map)),
+      Row("a", "b"))
+  }
+
+  test("explode on map with aliases") {
+    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+    checkAnswer(
+      df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+      Row("a", "b"))
+  }
+
  test("collect on column produced by a binary operator") {
    val df = Seq((1, 2, 3)).toDF("a", "b", "c")
    checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
From e710fe472d5e07ec25047effe63ee558da31cef6 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 14 May 2015 00:29:58 +0000
Subject: [PATCH 2/7] add java and python

---
 python/pyspark/sql/dataframe.py                 |  8 ++++++--
 python/pyspark/sql/functions.py                 |  3 +++
 python/pyspark/sql/tests.py                     | 15 +++++++++++++++
 .../scala/org/apache/spark/sql/Column.scala     | 17 ++++++++++++++++-
 4 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 078acfdf7e2df..0adf17131abab 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1511,13 +1511,17 @@ def inSet(self, *cols):
    isNull = _unary_op("isNull", "True if the current expression is null.")
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    def alias(self, alias):
+    def alias(self, *alias):
        """Return a alias for this column
 
        >>> df.select(df.age.alias("age2")).collect()
        [Row(age2=2), Row(age2=5)]
        """
-        return Column(getattr(self._jc, "as")(alias))
+
+        if len(alias) == 1:
+            return Column(getattr(self._jc, "as")(alias[0]))
+        else:
+            return Column(getattr(self._jc, "as")(alias))
 
    @ignore_unicode_prefix
    def cast(self, dataType):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38a043a3c59d7..81241ac655cad 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -77,6 +77,9 @@ def _(col1, col2):
    'sqrt': 'Computes the square root of the specified float value.',
    'abs': 'Computes the absolute value.',
 
+    # table generating functions
+    'explode': 'Returns a new row for each element in the given array or map.',
+
    # unary math functions
    'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range ' +
            '0.0 through pi.',
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1922d03af61da..624143d00bdeb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,21 @@ def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
 
+    def test_explode(self):
+        from pyspark.sql.functions import explode
+        d = [Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]
+        rdd = self.sc.parallelize(d)
+        data = self.sqlCtx.createDataFrame(rdd)
+
+        result = data.select(explode(data.intlist).alias("a")).select("a").collect()
+        self.assertEqual(result[0][0], 1)
+        self.assertEqual(result[1][0], 2)
+        self.assertEqual(result[2][0], 3)
+
+        result = data.select(explode(data.mapfield).alias("a", "b")).select("a", "b").collect()
+        self.assertEqual(result[0][0], "a")
+        self.assertEqual(result[0][1], "b")
+
    def test_udf_with_callable(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index d04b23428a01e..2210ac9c29403 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import scala.language.implicitConversions
+import scala.collection.JavaConversions._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
@@ -616,7 +617,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
  def as(alias: String): Column = Alias(expr, alias)()
 
  /**
-   * Assigns the given aliases to the results of a table generating function.
+   * (Scala-specific) Assigns the given aliases to the results of a table generating function.
   * {{{
   *   // Assigns the names `key` and `value` to explode's two output columns.
   *   df.select(explode($"myMap").as("key" :: "value" :: Nil))
@@ -626,6 +627,20 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   */
  def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases)
 
+  /**
+   * Assigns the given aliases to the results of a table generating function.
+   * {{{
+   *   // Assigns the names `key` and `value` to explode's two output columns.
+   *   df.select(explode($"myMap").as("key" :: "value" :: Nil))
+   * }}}
+   *
+   * @group expr_ops
+   */
+  def as(aliases: Array[String]): Column = MultiAlias(expr, aliases)
+
+  /** Used for multi aliases from python. */
+  protected def as(aliases: java.util.List[String]): Column = MultiAlias(expr, aliases)
+
  /**
   * Gives the column an alias.
   * {{{
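Note: the Array[String] overload added above is the Java-friendly spelling of the Scala Seq
variant, and the protected java.util.List overload is the bridge Python calls into. A sketch
with the same hypothetical `df`/`myMap` as before:

    // Equivalent to .as("key" :: "value" :: Nil), but callable from Java.
    df.select(explode($"myMap").as(Array("key", "value")))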
From f9e1e3e91ccc6b54766c0b3bbb840816aad3f10b Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 14 May 2015 02:25:33 +0000
Subject: [PATCH 3/7] fix python, add since

---
 python/pyspark/sql/dataframe.py                 |  3 ++-
 python/pyspark/sql/functions.py                 | 23 ++++++++++++++++---
 python/pyspark/sql/tests.py                     |  2 +-
 .../scala/org/apache/spark/sql/Column.scala     |  5 ++--
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cc86e5177e188..7911dfe0538a8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1521,7 +1521,8 @@ def alias(self, *alias):
        if len(alias) == 1:
            return Column(getattr(self._jc, "as")(alias[0]))
        else:
-            return Column(getattr(self._jc, "as")(alias))
+            sc = SparkContext._active_spark_context
+            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
 
    @ignore_unicode_prefix
    def cast(self, dataType):
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 29dac4b7d7f92..6cd6974b0e5bb 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -78,9 +78,6 @@ def _(col1, col2):
    'sqrt': 'Computes the square root of the specified float value.',
    'abs': 'Computes the absolute value.',
 
-    # table generating functions
-    'explode': 'Returns a new row for each element in the given array or map.',
-
    # unary math functions
    'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range ' +
            '0.0 through pi.',
@@ -172,6 +169,26 @@ def approxCountDistinct(col, rsd=None):
    return Column(jc)
 
 
+def explode(col):
+    """Returns a new row for each element in the given array or map.
+
+    >>> from pyspark.sql import Row
+    >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+    >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
+    [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
+
+    >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+    +---+-----+
+    |key|value|
+    +---+-----+
+    |  a|    b|
+    +---+-----+
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.explode(_to_java_column(col))
+    return Column(jc)
+
+
 def coalesce(*cols):
    """Returns the first column that is not null.
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 624143d00bdeb..d37c5dbed7f6b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -119,7 +119,7 @@ def tearDownClass(cls):
 
    def test_explode(self):
        from pyspark.sql.functions import explode
-        d = [Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})]
+        d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
        rdd = self.sc.parallelize(d)
        data = self.sqlCtx.createDataFrame(rdd)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 211d49e3a8111..dc0aeea7c4aea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -736,6 +736,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   * }}}
   *
   * @group expr_ops
+   * @since 1.4.0
   */
  def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases)
 
@@ -747,12 +748,10 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   * }}}
   *
   * @group expr_ops
+   * @since 1.4.0
   */
  def as(aliases: Array[String]): Column = MultiAlias(expr, aliases)
 
-  /** Used for multi aliases from python. */
-  protected def as(aliases: java.util.List[String]): Column = MultiAlias(expr, aliases)
-
  /**
   * Gives the column an alias.
   * {{{
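Note: patch 4 below fixes self joins on generator output. Both sides of such a join expose
generator-output attributes with identical expression ids, so the analyzer's deduplication
logic needs a Generate case, mirroring what it already does for Project and Aggregate. A
sketch of the query shape it unblocks (hypothetical data, same implicits as the tests):

    val exploded = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
      .select(explode($"intList").as("i"))

    // Resolvable only once one side receives fresh attribute instances.
    exploded.join(exploded, exploded("i") === exploded("i"))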
From d3faa05b17f9106f415dfc61b445e18473a51e2a Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 14 May 2015 06:24:20 +0000
Subject: [PATCH 4/7] fix self join case

---
 .../apache/spark/sql/catalyst/analysis/Analyzer.scala   | 5 +++++
 .../org/apache/spark/sql/ColumnExpressionSuite.scala    | 9 +++++++++
 2 files changed, 14 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 483ad88f5b95c..48daf02137f16 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -321,6 +321,11 @@ class Analyzer(
        case oldVersion @ Aggregate(_, aggregateExpressions, _)
            if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
+
+        case oldVersion: Generate
+            if AttributeSet(oldVersion.generatorOutput).intersect(conflictingAttributes).nonEmpty =>
+          val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+          (oldVersion, oldVersion.copy(generatorOutput = newOutput))
      }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass.
        sys.error(
          s"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 314efa8f32c1a..9bdf201b3be7c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -78,6 +78,15 @@ class ColumnExpressionSuite extends QueryTest {
      Row("a", "b"))
  }
 
+  test("self join explode") {
+    val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+    val exploded = df.select(explode('intList).as('i))
+
+    checkAnswer(
+      exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+      Row(3) :: Nil)
+  }
+
  test("collect on column produced by a binary operator") {
    val df = Seq((1, 2, 3)).toDF("a", "b", "c")
    checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
From 81b5da361157dd159415d54b46caeb802173af73 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 14 May 2015 06:36:38 +0000
Subject: [PATCH 5/7] style

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala   | 2 +-
 .../spark/sql/catalyst/plans/logical/basicOperators.scala   | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 48daf02137f16..48ad17935f98e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -323,7 +323,7 @@ class Analyzer(
          (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
        case oldVersion: Generate
-            if AttributeSet(oldVersion.generatorOutput).intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
          val newOutput = oldVersion.generatorOutput.map(_.newInstance())
          (oldVersion, oldVersion.copy(generatorOutput = newOutput))
      }.headOption.getOrElse { // Only handle first case, others will be fixed on the next pass.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0f349f9d11415..01f4b6e9bb77d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -59,6 +59,9 @@ case class Generate(
    child: LogicalPlan)
  extends UnaryNode {
 
+  /** The set of all attributes produced by this node. */
+  def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+
  override lazy val resolved: Boolean = {
    generator.resolved &&
      childrenResolved &&
From 6f80ba36474eeb868f2912c3cb4f8412e3076882 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 14 May 2015 16:12:42 -0700
Subject: [PATCH 6/7] Update dataframe.py

---
 python/pyspark/sql/dataframe.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7911dfe0538a8..f9a0aa2afaff8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1512,7 +1512,8 @@ def inSet(self, *cols):
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
    def alias(self, *alias):
-        """Return a alias for this column
+        """Returns this column aliased with a new name or names (in the case of expressions that 
+        return more than one column, such as explode).
 
        >>> df.select(df.age.alias("age2")).collect()
        [Row(age2=2), Row(age2=5)]
From 7ee2c8706db07c6e8f652e10fbfff66de324758b Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Thu, 14 May 2015 17:09:15 -0700
Subject: [PATCH 7/7] whitespace

---
 python/pyspark/sql/dataframe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f9a0aa2afaff8..2ed95ac8e2505 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1512,7 +1512,7 @@ def inSet(self, *cols):
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
    def alias(self, *alias):
-        """Returns this column aliased with a new name or names (in the case of expressions that 
+        """Returns this column aliased with a new name or names (in the case of expressions that
        return more than one column, such as explode).
 
        >>> df.select(df.age.alias("age2")).collect()
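Note: pulling the series together, an end-to-end sketch of the API these seven patches add
(Scala; assumes a SQLContext with implicits imported):

    import org.apache.spark.sql.functions._
    import sqlContext.implicits._

    val df = Seq((1, Seq(1, 2, 3), Map("a" -> "b"))).toDF("id", "intList", "myMap")

    // Arrays: one generated row per element, alongside the other columns.
    df.select($"id", explode($"intList").as("i")).show()

    // Maps: one generated row per entry, with key and value columns named via
    // the multi-alias overload.
    df.select(explode($"myMap").as("key" :: "value" :: Nil)).show()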