Skip to content

Commit

Permalink
[SPARK-6210] [SQL] use prettyString as column name in agg()
Browse files Browse the repository at this point in the history
use prettyString instead of toString() (which include id of expression) as column name in agg()

Author: Davies Liu <davies@databricks.com>

Closes #5006 from davies/prettystring and squashes the following commits:

cb1fdcf [Davies Liu] use prettyString as column name in agg()

(cherry picked from commit b38e073)
Signed-off-by: Reynold Xin <rxin@databricks.com>
  • Loading branch information
Davies Liu authored and rxin committed Mar 14, 2015
1 parent 3012781 commit ad47563
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
32 changes: 16 additions & 16 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,11 +631,11 @@ def groupBy(self, *cols):
for all the available aggregate functions.
>>> df.groupBy().avg().collect()
[Row(AVG(age#0)=3.5)]
[Row(AVG(age)=3.5)]
>>> df.groupBy('name').agg({'age': 'mean'}).collect()
[Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
>>> df.groupBy(df.name).avg().collect()
[Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
[Row(name=u'Bob', AVG(age)=5.0), Row(name=u'Alice', AVG(age)=2.0)]
"""
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
Expand All @@ -647,10 +647,10 @@ def agg(self, *exprs):
(shorthand for df.groupBy.agg()).
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age#0)=5)]
[Row(MAX(age)=5)]
>>> from pyspark.sql import functions as F
>>> df.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=2)]
[Row(MIN(age)=2)]
"""
return self.groupBy().agg(*exprs)

Expand Down Expand Up @@ -766,7 +766,7 @@ def agg(self, *exprs):
>>> from pyspark.sql import functions as F
>>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
[Row(MIN(age)=5), Row(MIN(age)=2)]
"""
assert exprs, "exprs should not be empty"
if len(exprs) == 1 and isinstance(exprs[0], dict):
Expand Down Expand Up @@ -795,9 +795,9 @@ def mean(self, *cols):
for each group. This is an alias for `avg`.
>>> df.groupBy().mean('age').collect()
[Row(AVG(age#0)=3.5)]
[Row(AVG(age)=3.5)]
>>> df3.groupBy().mean('age', 'height').collect()
[Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
[Row(AVG(age)=3.5, AVG(height)=82.5)]
"""

@df_varargs_api
Expand All @@ -806,9 +806,9 @@ def avg(self, *cols):
for each group.
>>> df.groupBy().avg('age').collect()
[Row(AVG(age#0)=3.5)]
[Row(AVG(age)=3.5)]
>>> df3.groupBy().avg('age', 'height').collect()
[Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
[Row(AVG(age)=3.5, AVG(height)=82.5)]
"""

@df_varargs_api
Expand All @@ -817,9 +817,9 @@ def max(self, *cols):
each group.
>>> df.groupBy().max('age').collect()
[Row(MAX(age#0)=5)]
[Row(MAX(age)=5)]
>>> df3.groupBy().max('age', 'height').collect()
[Row(MAX(age#4L)=5, MAX(height#5L)=85)]
[Row(MAX(age)=5, MAX(height)=85)]
"""

@df_varargs_api
Expand All @@ -828,9 +828,9 @@ def min(self, *cols):
each group.
>>> df.groupBy().min('age').collect()
[Row(MIN(age#0)=2)]
[Row(MIN(age)=2)]
>>> df3.groupBy().min('age', 'height').collect()
[Row(MIN(age#4L)=2, MIN(height#5L)=80)]
[Row(MIN(age)=2, MIN(height)=80)]
"""

@df_varargs_api
Expand All @@ -839,9 +839,9 @@ def sum(self, *cols):
group.
>>> df.groupBy().sum('age').collect()
[Row(SUM(age#0)=7)]
[Row(SUM(age)=7)]
>>> df3.groupBy().sum('age', 'height').collect()
[Row(SUM(age#4L)=7, SUM(height#5L)=165)]
[Row(SUM(age)=7, SUM(height)=165)]
"""


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
}.mkString(", ")

/** String representation of this node without any children */
def simpleString = s"$nodeName $argString"
def simpleString = s"$nodeName $argString".trim

override def toString: String = treeString

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
case expr: Expression => Alias(expr, expr.prettyString)()
}
DataFrame(
df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
Expand All @@ -64,7 +64,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
}
columnExprs.map { c =>
val a = f(c)
Alias(a, a.toString)()
Alias(a, a.prettyString)()
}
}

Expand Down Expand Up @@ -116,7 +116,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
def agg(exprs: Map[String, String]): DataFrame = {
exprs.map { case (colName, expr) =>
val a = strToExpr(expr)(df(colName).expr)
Alias(a, a.toString)()
Alias(a, a.prettyString)()
}.toSeq
}

Expand Down Expand Up @@ -160,7 +160,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
def agg(expr: Column, exprs: Column*): DataFrame = {
val aggExprs = (expr +: exprs).map(_.expr).map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
case expr: Expression => Alias(expr, expr.prettyString)()
}
DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
}
Expand Down

0 comments on commit ad47563

Please sign in to comment.