From 371a3f722c2c4c92275c30bfe060b353c719a216 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Feb 2015 00:15:33 +0800 Subject: [PATCH 1/9] Compute aggregation function on specified numeric columns. --- .../org/apache/spark/sql/DataFrameImpl.scala | 9 +++- .../org/apache/spark/sql/GroupedData.scala | 52 ++++++++++++++++++- .../org/apache/spark/sql/DataFrameSuite.scala | 12 +++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index bb5c6226a221..001128bc3805 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -93,7 +93,14 @@ private[sql] class DataFrameImpl protected[sql]( queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get } } - + + protected[sql] def numericColumns(colNames: String*): Seq[Expression] = { + schema.fields.filter(n => colNames.contains(n.name) && n.dataType.isInstanceOf[NumericType]) + .map { n => + queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + override def toDataFrame(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 3c20676355c9..2dcd76288613 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -45,7 +45,15 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio Alias(a, a.toString)() } } - + + private[this] def aggregateNumericColumns(colName: String, colNames: String*) + (f: Expression => Expression): Seq[NamedExpression] = { + df.numericColumns((Seq(colName) ++ colNames):_*).map { c => + val a = f(c) + Alias(a, a.toString)() + } + } + private[this] def strToExpr(expr: String): (Expression => Expression) = { expr.toLowerCase match { case "avg" | "average" | "mean" => Average @@ -149,28 +157,70 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio * The resulting [[DataFrame]] will also contain the grouping columns. */ def mean(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the average value for given numeric columns for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + def mean(colName: String, colNames: String*): DataFrame = { + aggregateNumericColumns(colName, colNames:_*)(Average) + } /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ def max(): DataFrame = aggregateNumericColumns(Max) + + /** + * Compute the max value for given numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + def max(colName: String, colNames: String*): DataFrame = { + aggregateNumericColumns(colName, colNames:_*)(Max) + } /** * Compute the mean value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ def avg(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the mean value for given numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + def avg(colName: String, colNames: String*): DataFrame = { + aggregateNumericColumns(colName, colNames:_*)(Average) + } /** * Compute the min value for each numeric column for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ def min(): DataFrame = aggregateNumericColumns(Min) + + /** + * Compute the min value for given numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + def min(colName: String, colNames: String*): DataFrame = { + aggregateNumericColumns(colName, colNames:_*)(Min) + } /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. */ def sum(): DataFrame = aggregateNumericColumns(Sum) + + /** + * Compute the sum for given numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + def sum(colName: String, colNames: String*): DataFrame = { + aggregateNumericColumns(colName, colNames:_*)(Sum) + } + + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 33b35f376b27..53e4ec09dd7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -154,6 +154,18 @@ class DataFrameSuite extends QueryTest { testData2.agg(sum('b)), Row(9) ) + + val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) + .toDataFrame("key", "value1", "value2", "rest") + + checkAnswer( + df1.groupBy("key").min(), + df1.groupBy("key").min("value1", "value2").collect + ) + checkAnswer( + df1.groupBy("key").min("value2"), + Seq(Row("a",0), Row("b",4)) + ) } test("convert $\"attribute name\" into unresolved attribute") { From 27069c39b84ae6312efd80b7a2965c45a17348da Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Feb 2015 14:33:52 +0800 Subject: [PATCH 2/9] Combine functions and add varargs annotation. --- .../org/apache/spark/sql/GroupedData.scala | 89 +++++++++---------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 2dcd76288613..ba80cf30f236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -45,7 +45,8 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio Alias(a, a.toString)() } } - + + @scala.annotation.varargs private[this] def aggregateNumericColumns(colName: String, colNames: String*) (f: Expression => Expression): Seq[NamedExpression] = { df.numericColumns((Seq(colName) ++ colNames):_*).map { c => @@ -155,72 +156,70 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. */ - def mean(): DataFrame = aggregateNumericColumns(Average) - - /** - * Compute the average value for given numeric columns for each group. This is an alias for `avg`. - * The resulting [[DataFrame]] will also contain the grouping columns. - */ - def mean(colName: String, colNames: String*): DataFrame = { - aggregateNumericColumns(colName, colNames:_*)(Average) + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = { + if (colNames.isEmpty) { + aggregateNumericColumns(Average) + } else { + aggregateNumericColumns(colNames.head, colNames.tail:_*)(Average) + } } - - /** - * Compute the max value for each numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - */ - def max(): DataFrame = aggregateNumericColumns(Max) /** - * Compute the max value for given numeric columns for each group. + * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. */ - def max(colName: String, colNames: String*): DataFrame = { - aggregateNumericColumns(colName, colNames:_*)(Max) + @scala.annotation.varargs + def max(colNames: String*): DataFrame = { + if (colNames.isEmpty) { + aggregateNumericColumns(Max) + } else { + aggregateNumericColumns(colNames.head, colNames.tail:_*)(Max) + } } /** * Compute the mean value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. */ - def avg(): DataFrame = aggregateNumericColumns(Average) - - /** - * Compute the mean value for given numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - */ - def avg(colName: String, colNames: String*): DataFrame = { - aggregateNumericColumns(colName, colNames:_*)(Average) + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = { + if (colNames.isEmpty) { + aggregateNumericColumns(Average) + } else { + aggregateNumericColumns(colNames.head, colNames.tail:_*)(Average) + } } /** * Compute the min value for each numeric column for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. */ - def min(): DataFrame = aggregateNumericColumns(Min) - - /** - * Compute the min value for given numeric column for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - */ - def min(colName: String, colNames: String*): DataFrame = { - aggregateNumericColumns(colName, colNames:_*)(Min) + @scala.annotation.varargs + def min(colNames: String*): DataFrame = { + if (colNames.isEmpty) { + aggregateNumericColumns(Min) + } else { + aggregateNumericColumns(colNames.head, colNames.tail:_*)(Min) + } } /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. */ - def sum(): DataFrame = aggregateNumericColumns(Sum) - - /** - * Compute the sum for given numeric columns for each group. - * The resulting [[DataFrame]] will also contain the grouping columns. - */ - def sum(colName: String, colNames: String*): DataFrame = { - aggregateNumericColumns(colName, colNames:_*)(Sum) - } - - + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = { + if (colNames.isEmpty) { + aggregateNumericColumns(Sum) + } else { + aggregateNumericColumns(colNames.head, colNames.tail:_*)(Sum) + } + } } From b1a24fc0275b895376e3de0af1cd7b3e2c72048b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 14 Feb 2015 19:34:16 +0800 Subject: [PATCH 3/9] Address comments. --- .../org/apache/spark/sql/DataFrameImpl.scala | 25 ++++++----- .../org/apache/spark/sql/GroupedData.scala | 42 ++++--------------- .../org/apache/spark/sql/DataFrameSuite.scala | 20 ++++----- 3 files changed, 32 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 55d3509b5c79..407c792fb58a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -88,19 +88,24 @@ private[sql] class DataFrameImpl protected[sql]( } } - protected[sql] def numericColumns: Seq[Expression] = { - schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + protected[sql] def numericColumns(colNames: String*): Seq[Expression] = { + val allNumbericCols = schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map(_.name) + val diff = colNames.diff(allNumbericCols) + if (diff.nonEmpty) { + val diffStr = diff.mkString(", ") + throw new RuntimeException( + s"""Cannot resolve column names "($diffStr)" among (${schema.fieldNames.mkString(", ")})""") + } + val colsToResolve: Seq[String] = if (colNames.isEmpty) { + allNumbericCols + } else { + colNames + } + colsToResolve.map { n => + queryExecution.analyzed.resolve(n, sqlContext.analyzer.resolver).get } } - protected[sql] def numericColumns(colNames: String*): Seq[Expression] = { - schema.fields.filter(n => colNames.contains(n.name) && n.dataType.isInstanceOf[NumericType]) - .map { n => - queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get - } - } - override def toDF(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 13b0a5335efe..a420094022ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -39,17 +39,9 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) } - private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { - df.numericColumns.map { c => - val a = f(c) - Alias(a, a.toString)() - } - } - - @scala.annotation.varargs - private[this] def aggregateNumericColumns(colName: String, colNames: String*) + private[this] def aggregateNumericColumns(colNames: String*) (f: Expression => Expression): Seq[NamedExpression] = { - df.numericColumns((Seq(colName) ++ colNames):_*).map { c => + df.numericColumns(colNames:_*).map { c => val a = f(c) Alias(a, a.toString)() } @@ -165,11 +157,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - if (colNames.isEmpty) { - aggregateNumericColumns(Average) - } else { - aggregateNumericColumns(colNames.head, colNames.tail:_*)(Average) - } + aggregateNumericColumns(colNames:_*)(Average) } /** @@ -179,11 +167,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio */ @scala.annotation.varargs def max(colNames: String*): DataFrame = { - if (colNames.isEmpty) { - aggregateNumericColumns(Max) - } else { - aggregateNumericColumns(colNames.head, colNames.tail:_*)(Max) - } + aggregateNumericColumns(colNames:_*)(Max) } /** @@ -193,11 +177,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - if (colNames.isEmpty) { - aggregateNumericColumns(Average) - } else { - aggregateNumericColumns(colNames.head, colNames.tail:_*)(Average) - } + aggregateNumericColumns(colNames:_*)(Average) } /** @@ -207,11 +187,7 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio */ @scala.annotation.varargs def min(colNames: String*): DataFrame = { - if (colNames.isEmpty) { - aggregateNumericColumns(Min) - } else { - aggregateNumericColumns(colNames.head, colNames.tail:_*)(Min) - } + aggregateNumericColumns(colNames:_*)(Min) } /** @@ -221,10 +197,6 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - if (colNames.isEmpty) { - aggregateNumericColumns(Sum) - } else { - aggregateNumericColumns(colNames.head, colNames.tail:_*)(Sum) - } + aggregateNumericColumns(colNames:_*)(Sum) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index ff22484b55fb..524571d9cc63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -162,24 +162,24 @@ class DataFrameSuite extends QueryTest { testData2.groupBy("a").agg(Map("b" -> "sum")), Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil ) - } - - test("agg without groups") { - checkAnswer( - testData2.agg(sum('b)), - Row(9) - ) val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) - .toDataFrame("key", "value1", "value2", "rest") + .toDF("key", "value1", "value2", "rest") checkAnswer( df1.groupBy("key").min(), - df1.groupBy("key").min("value1", "value2").collect + df1.groupBy("key").min("value1", "value2").collect() ) checkAnswer( df1.groupBy("key").min("value2"), - Seq(Row("a",0), Row("b",4)) + Seq(Row("a", 0), Row("b", 4)) + ) + } + + test("agg without groups") { + checkAnswer( + testData2.agg(sum('b)), + Row(9) ) } From 4c63a01695ca68eeb003b642d8553c5916868012 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 15 Feb 2015 00:06:11 +0800 Subject: [PATCH 4/9] Fix pyspark. --- python/pyspark/sql/dataframe.py | 42 ++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1438fe5285cc..a45e7de00cd7 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -714,31 +714,45 @@ def count(self): [Row(age=2, count=1), Row(age=5, count=1)] """ - @dfapi - def mean(self): + def mean(self, *cols): """Compute the average value for each numeric columns for each group. This is an alias for `avg`.""" - - @dfapi - def avg(self): + jcols = ListConverter().convert(list(cols), + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.mean(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def avg(self, *cols): """Compute the average value for each numeric columns for each group.""" + jcols = ListConverter().convert(list(cols), + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.avg(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) - @dfapi - def max(self): + def max(self, *cols): """Compute the max value for each numeric columns for each group. """ - - @dfapi - def min(self): + jcols = ListConverter().convert(list(cols), + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.max(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def min(self, *cols): """Compute the min value for each numeric column for each group.""" - - @dfapi - def sum(self): + jcols = ListConverter().convert(list(cols), + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.min(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def sum(self, *cols): """Compute the sum for each numeric columns for each group.""" - + jcols = ListConverter().convert(list(cols), + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.sum(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) def _create_column_from_literal(literal): sc = SparkContext._active_spark_context From 880c2acdc6beb9a73515371e4356bc37329b3693 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 15 Feb 2015 01:14:17 +0800 Subject: [PATCH 5/9] Fix Python style checks. --- python/pyspark/sql/dataframe.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a45e7de00cd7..6e5ed5d2a49e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -721,7 +721,7 @@ def mean(self, *cols): self.sql_ctx._sc._gateway._gateway_client) jdf = self._jdf.mean(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) - + def avg(self, *cols): """Compute the average value for each numeric columns for each group.""" @@ -737,7 +737,7 @@ def max(self, *cols): self.sql_ctx._sc._gateway._gateway_client) jdf = self._jdf.max(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) - + def min(self, *cols): """Compute the min value for each numeric column for each group.""" @@ -745,7 +745,7 @@ def min(self, *cols): self.sql_ctx._sc._gateway._gateway_client) jdf = self._jdf.min(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) - + def sum(self, *cols): """Compute the sum for each numeric columns for each group.""" @@ -754,6 +754,7 @@ def sum(self, *cols): jdf = self._jdf.sum(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) + def _create_column_from_literal(literal): sc = SparkContext._active_spark_context return sc._jvm.functions.lit(literal) From b079e6bfd0efcc947f23283e66b5692dffcdd5bd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 15 Feb 2015 12:50:12 +0800 Subject: [PATCH 6/9] Remove duplicate codes. --- python/pyspark/sql/dataframe.py | 37 +++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6e5ed5d2a49e..ea740718f96d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -664,6 +664,18 @@ def _api(self): return _api +def df_varargs_api(f, *args): + def _api(self): + jargs = ListConverter().convert(list(args), + self.sql_ctx._sc._gateway._gateway_client) + name = f.__name__ + jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + class GroupedData(object): """ @@ -714,45 +726,30 @@ def count(self): [Row(age=2, count=1), Row(age=5, count=1)] """ + @df_varargs_api def mean(self, *cols): """Compute the average value for each numeric columns for each group. This is an alias for `avg`.""" - jcols = ListConverter().convert(list(cols), - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.mean(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) + @df_varargs_api def avg(self, *cols): """Compute the average value for each numeric columns for each group.""" - jcols = ListConverter().convert(list(cols), - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.avg(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) + @df_varargs_api def max(self, *cols): """Compute the max value for each numeric columns for each group. """ - jcols = ListConverter().convert(list(cols), - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.max(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) + @df_varargs_api def min(self, *cols): """Compute the min value for each numeric column for each group.""" - jcols = ListConverter().convert(list(cols), - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.min(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) + @df_varargs_api def sum(self, *cols): """Compute the sum for each numeric columns for each group.""" - jcols = ListConverter().convert(list(cols), - self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.sum(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) - return DataFrame(jdf, self.sql_ctx) def _create_column_from_literal(literal): From 54ed0c43b1850c349a59a702af325b4dc5946971 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 15 Feb 2015 16:28:32 +0800 Subject: [PATCH 7/9] Address comments. --- python/pyspark/sql/dataframe.py | 48 +++++++++++++++---- .../org/apache/spark/sql/DataFrameImpl.scala | 13 +++-- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ea740718f96d..28a59e73a341 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -664,9 +664,9 @@ def _api(self): return _api -def df_varargs_api(f, *args): - def _api(self): - jargs = ListConverter().convert(list(args), +def df_varargs_api(f): + def _api(self, *args): + jargs = ListConverter().convert(args, self.sql_ctx._sc._gateway._gateway_client) name = f.__name__ jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs)) @@ -729,27 +729,57 @@ def count(self): @df_varargs_api def mean(self, *cols): """Compute the average value for each numeric columns - for each group. This is an alias for `avg`.""" + for each group. This is an alias for `avg`. + + >>> df.groupBy().mean('age').collect() + [Row(AVG(age#0)=3.5)] + >>> df3.groupBy().mean('age', 'height').collect() + [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)] + """ @df_varargs_api def avg(self, *cols): """Compute the average value for each numeric columns - for each group.""" + for each group. + + >>> df.groupBy().avg('age').collect() + [Row(AVG(age#0)=3.5)] + >>> df3.groupBy().avg('age', 'height').collect() + [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)] + """ @df_varargs_api def max(self, *cols): """Compute the max value for each numeric columns for - each group. """ + each group. + + >>> df.groupBy().max('age').collect() + [Row(MAX(age#0)=5)] + >>> df3.groupBy().max('age', 'height').collect() + [Row(MAX(age#4)=5, MAX(height#5)=85)] + """ @df_varargs_api def min(self, *cols): """Compute the min value for each numeric column for - each group.""" + each group. + + >>> df.groupBy().min('age').collect() + [Row(MIN(age#0)=2)] + >>> df3.groupBy().min('age', 'height').collect() + [Row(MIN(age#4)=2, MIN(height#5)=80)] + """ @df_varargs_api def sum(self, *cols): """Compute the sum for each numeric columns for each - group.""" + group. + + >>> df.groupBy().sum('age').collect() + [Row(SUM(age#0)=7)] + >>> df3.groupBy().sum('age', 'height').collect() + [Row(SUM(age#4)=7, SUM(height#5)=165)] + """ def _create_column_from_literal(literal): @@ -957,6 +987,8 @@ def _test(): globs['sqlCtx'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 407c792fb58a..7ce8aec9499b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -90,19 +90,18 @@ private[sql] class DataFrameImpl protected[sql]( protected[sql] def numericColumns(colNames: String*): Seq[Expression] = { val allNumbericCols = schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map(_.name) - val diff = colNames.diff(allNumbericCols) - if (diff.nonEmpty) { - val diffStr = diff.mkString(", ") - throw new RuntimeException( - s"""Cannot resolve column names "($diffStr)" among (${schema.fieldNames.mkString(", ")})""") - } val colsToResolve: Seq[String] = if (colNames.isEmpty) { allNumbericCols } else { colNames } colsToResolve.map { n => - queryExecution.analyzed.resolve(n, sqlContext.analyzer.resolver).get + if (colNames.isEmpty || allNumbericCols.contains(n)) { + queryExecution.analyzed.resolve(n, sqlContext.analyzer.resolver).get + } else { + throw new RuntimeException( + s"""Cannot resolve column name "($n)" among (${schema.fieldNames.mkString(", ")})""") + } } } From 353fad714409fb96a99455825fd6b44c56de43ee Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 15 Feb 2015 18:30:46 +0800 Subject: [PATCH 8/9] For python unit tests. --- python/pyspark/sql/functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 39aa550eeb5a..d0e090607ff4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -158,6 +158,8 @@ def _test(): globs['sqlCtx'] = SQLContext(sc) globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF() globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF() + globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80), + Row(name='Bob', age=5, height=85)]).toDF() (failure_count, test_count) = doctest.testmod( pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) From 94468966d5595bf97b545c0a127a787b23566fd0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 16 Feb 2015 17:23:21 +0800 Subject: [PATCH 9/9] For comments. --- .../org/apache/spark/sql/DataFrameImpl.scala | 17 ++--------- .../org/apache/spark/sql/GroupedData.scala | 28 +++++++++++++++---- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 7ce8aec9499b..9eb0c131405d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -88,20 +88,9 @@ private[sql] class DataFrameImpl protected[sql]( } } - protected[sql] def numericColumns(colNames: String*): Seq[Expression] = { - val allNumbericCols = schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map(_.name) - val colsToResolve: Seq[String] = if (colNames.isEmpty) { - allNumbericCols - } else { - colNames - } - colsToResolve.map { n => - if (colNames.isEmpty || allNumbericCols.contains(n)) { - queryExecution.analyzed.resolve(n, sqlContext.analyzer.resolver).get - } else { - throw new RuntimeException( - s"""Cannot resolve column name "($n)" among (${schema.fieldNames.mkString(", ")})""") - } + protected[sql] def numericColumns(): Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index a420094022ee..a5a677b68863 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.types.NumericType + /** @@ -39,12 +41,28 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) } - private[this] def aggregateNumericColumns(colNames: String*) - (f: Expression => Expression): Seq[NamedExpression] = { - df.numericColumns(colNames:_*).map { c => - val a = f(c) - Alias(a, a.toString)() + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + : Seq[NamedExpression] = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.numericColumns + } else { + // Make sure all specified columns are numeric + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[NumericType]) { + throw new AnalysisException( + s""""$colName" is not a numeric column. """ + + "Aggregation function can only be performed on a numeric column.") + } + namedExpr } + } + columnExprs.map { c => + val a = f(c) + Alias(a, a.toString)() + } } private[this] def strToExpr(expr: String): (Expression => Expression) = {