[SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas
```
select(['cola', 'colb'])

groupby(['colA', 'colB'])
groupby([df.colA, df.colB])

df.sort('A', ascending=True)
df.sort(['A', 'B'], ascending=True)
df.sort(['A', 'B'], ascending=[1, 0])
```
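
For readers trying this out, here is a minimal end-to-end sketch of the new calling conventions (the app name and toy data are illustrative, not from the patch; requires a Spark build at or after this commit):

```
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local", "pandas-compat-demo")  # hypothetical app name
sqlCtx = SQLContext(sc)
df = sqlCtx.createDataFrame([Row(A=1, B=2), Row(A=1, B=3), Row(A=2, B=1)])

# pandas-style: sort by A ascending, then B descending ...
df.sort(['A', 'B'], ascending=[1, 0]).show()
# ... which is equivalent to the classic Column-expression form
df.sort(df.A.asc(), df.B.desc()).show()

# single-list arguments now work for select/groupby as well
df.select(['A', 'B']).show()
df.groupby(['A', df.B]).count().show()
```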

cc rxin

Author: Davies Liu <davies@databricks.com>

Closes #5544 from davies/compatibility and squashes the following commits:

4944058 [Davies Liu] add docstrings
adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility
bcbbcab [Davies Liu] support ascending as list
8dabdf0 [Davies Liu] improve API compatibility to pandas
Davies Liu authored and rxin committed Apr 17, 2015
1 parent dc48ba9 commit c84d916
Showing 3 changed files with 70 additions and 39 deletions.
96 changes: 66 additions & 30 deletions python/pyspark/sql/dataframe.py
@@ -485,30 +485,60 @@ def join(self, other, joinExprs=None, joinType=None):
         return DataFrame(jdf, self.sql_ctx)

     @ignore_unicode_prefix
-    def sort(self, *cols):
+    def sort(self, *cols, **kwargs):
         """Returns a new :class:`DataFrame` sorted by the specified column(s).

-        :param cols: list of :class:`Column` to sort by.
+        :param cols: list of :class:`Column` or column names to sort by.
+        :param ascending: sort by ascending order or not, could be bool, int
+            or list of bool, int (default: True).

         >>> df.sort(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.sort("age", ascending=False).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         >>> df.orderBy(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         >>> from pyspark.sql.functions import *
         >>> df.sort(asc("age")).collect()
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.orderBy(desc("age"), "name").collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        jcols = [_to_java_column(c) for c in cols]
+        ascending = kwargs.get('ascending', True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                jcols = [jc.desc() for jc in jcols]
+        elif isinstance(ascending, list):
+            jcols = [jc if asc else jc.desc()
+                     for asc, jc in zip(ascending, jcols)]
+        else:
+            raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
+
+        jdf = self._jdf.sort(self._jseq(jcols))
         return DataFrame(jdf, self.sql_ctx)

     orderBy = sort
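
The `ascending` handling above is the core of the pandas-style flag: a scalar applies to every sort key, while a list is zipped one-to-one against the keys. A pure-Python sketch of the same dispatch (the stand-in `_Key` class is illustrative, not part of the patch):

```
class _Key(object):
    """Stand-in for a JVM Column sort expression (illustrative only)."""
    def __init__(self, name, descending=False):
        self.name, self.descending = name, descending
    def desc(self):
        return _Key(self.name, True)
    def __repr__(self):
        return "%s %s" % (self.name, "DESC" if self.descending else "ASC")

def apply_ascending(keys, ascending=True):
    # Same dispatch as sort() above: a scalar flag covers all keys,
    # a list flag is zipped against the keys one-to-one.
    if isinstance(ascending, (bool, int)):
        return keys if ascending else [k.desc() for k in keys]
    if isinstance(ascending, list):
        return [k if asc else k.desc() for asc, k in zip(ascending, keys)]
    raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))

print(apply_ascending([_Key('A'), _Key('B')], ascending=[1, 0]))
# [A ASC, B DESC]
```

Note that, as in the patch, `zip` silently truncates when `ascending` is shorter than the column list, so mismatched lengths are best avoided.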

+    def _jseq(self, cols, converter=None):
+        """Return a JVM Seq of Columns from a list of Column or names"""
+        return _to_seq(self.sql_ctx._sc, cols, converter)
+
+    def _jcols(self, *cols):
+        """Return a JVM Seq of Columns from a list of Column or column names
+
+        If `cols` has only one list in it, cols[0] will be used as the list.
+        """
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        return self._jseq(cols, _to_java_column)
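
`_jcols` is what gives `select`/`groupBy` their single-list calling convention: one list argument is unwrapped into the varargs it stands for. A sketch of the unwrap rule in isolation (hypothetical helper name):

```
def unwrap_single_list(*cols):
    # Mirrors _jcols' unwrap rule: df.select(['a', 'b']) and
    # df.select('a', 'b') normalize to the same column tuple.
    if len(cols) == 1 and isinstance(cols[0], list):
        cols = cols[0]
    return tuple(cols)

assert unwrap_single_list('a', 'b') == unwrap_single_list(['a', 'b']) == ('a', 'b')
```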

     def describe(self, *cols):
         """Computes statistics for numeric columns.
@@ -523,9 +553,7 @@ def describe(self, *cols):
         min    2
         max    5
         """
-        cols = ListConverter().convert(cols,
-                                       self.sql_ctx._sc._gateway._gateway_client)
-        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+        jdf = self._jdf.describe(self._jseq(cols))
         return DataFrame(jdf, self.sql_ctx)

@ignore_unicode_prefix
@@ -607,9 +635,7 @@ def select(self, *cols):
         >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        jdf = self._jdf.select(self._jcols(*cols))
         return DataFrame(jdf, self.sql_ctx)

     def selectExpr(self, *expr):
@@ -620,8 +646,9 @@ def selectExpr(self, *expr):
         >>> df.selectExpr("age * 2", "abs(age)").collect()
         [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
         """
-        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
-        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+        if len(expr) == 1 and isinstance(expr[0], list):
+            expr = expr[0]
+        jdf = self._jdf.selectExpr(self._jseq(expr))
         return DataFrame(jdf, self.sql_ctx)
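
`selectExpr` gets the same single-list convenience, so both forms below are equivalent after this change (sketch, assuming the doctest DataFrame `df` with an `age` column):

```
df.selectExpr("age * 2", "abs(age)").collect()
df.selectExpr(["age * 2", "abs(age)"]).collect()  # new: single list of SQL expressions
```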

     @ignore_unicode_prefix
@@ -659,6 +686,8 @@ def groupBy(self, *cols):
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.

+        :func:`groupby` is an alias for :func:`groupBy`.
+
         :param cols: list of columns to group by.
             Each element should be a column name (string) or an expression (:class:`Column`).
@@ -668,12 +697,14 @@ def groupBy(self, *cols):
         [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
         >>> df.groupBy(df.name).avg().collect()
         [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        >>> df.groupBy(['name', df.age]).count().collect()
+        [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
         """
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self.sql_ctx._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        jdf = self._jdf.groupBy(self._jcols(*cols))
         return GroupedData(jdf, self.sql_ctx)

+    groupby = groupBy
+
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy.agg()``).
@@ -744,9 +775,7 @@ def dropna(self, how='any', thresh=None, subset=None):
         if thresh is None:
             thresh = len(subset) if how == 'any' else 1

-        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
-        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
-        return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
+        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)

     def fillna(self, value, subset=None):
         """Replace null values, alias for ``na.fill()``.
@@ -799,9 +828,7 @@ def fillna(self, value, subset=None):
         elif not isinstance(subset, (list, tuple)):
             raise ValueError("subset should be a list or tuple of column names")

-        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
-        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
-        return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
+        return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

     @ignore_unicode_prefix
     def withColumn(self, colName, col):
@@ -862,10 +889,8 @@ def _api(self):

 def df_varargs_api(f):
     def _api(self, *args):
-        jargs = ListConverter().convert(args,
-                                        self.sql_ctx._sc._gateway._gateway_client)
         name = f.__name__
-        jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
         return DataFrame(jdf, self.sql_ctx)
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
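
`df_varargs_api` is a small decorator factory: the decorated Python stub is replaced by a forwarder that calls the like-named JVM method with the converted varargs. A pure-Python analogue (all names here are illustrative, not from the patch):

```
def varargs_api(f):
    # Pure-Python analogue of df_varargs_api above: replace the decorated
    # stub with a forwarder that invokes the like-named method on a wrapped
    # backend object, keeping the stub's name and docstring.
    def _api(self, *args):
        return getattr(self._backend, f.__name__)(*args)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api
```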
@@ -912,9 +937,8 @@ def agg(self, *exprs):
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
-                                            self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+            jdf = self._jdf.agg(exprs[0]._jc,
+                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.sql_ctx)

     @dfapi
@@ -1006,6 +1030,19 @@ def _to_java_column(col):
     return jcol


+def _to_seq(sc, cols, converter=None):
+    """
+    Convert a list of Column (or names) into a JVM Seq of Column.
+
+    An optional `converter` could be used to convert items in `cols`
+    into JVM Column objects.
+    """
+    if converter:
+        cols = [converter(c) for c in cols]
+    jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
+    return sc._jvm.PythonUtils.toSeq(jcols)
+
+
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
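
`_to_seq` is the single Py4J conversion point that the removed `ListConverter` boilerplate collapses into. A usage sketch (assuming an active SparkContext `sc` and the doctest DataFrame `df`):

```
# Plain values pass through unchanged; with _to_java_column as the
# converter, names and Column objects become JVM Columns first.
name_seq = _to_seq(sc, ["age", "name"])                    # Seq of Python strings
col_seq = _to_seq(sc, ["age", df.name], _to_java_column)   # Seq of JVM Columns
```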
@@ -1177,8 +1214,7 @@ def inSet(self, *cols):
             cols = cols[0]
         cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
         sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
-        jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
+        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
         return Column(jc)

     # order
11 changes: 3 additions & 8 deletions python/pyspark/sql/functions.py
@@ -23,13 +23,11 @@
 if sys.version < "3":
     from itertools import imap as map

-from py4j.java_collections import ListConverter
-
 from pyspark import SparkContext
 from pyspark.rdd import _prepare_for_python_RDD
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq


 __all__ = ['countDistinct', 'approxCountDistinct', 'udf']
@@ -87,8 +85,7 @@ def countDistinct(col, *cols):
     [Row(c=2)]
     """
     sc = SparkContext._active_spark_context
-    jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
-    jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+    jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
     return Column(jc)


@@ -138,9 +135,7 @@ def __del__(self):

     def __call__(self, *cols):
         sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        sc._gateway._gateway_client)
-        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
         return Column(jc)
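
For context, the rewritten `__call__` is the path taken whenever a registered UDF is applied to columns. A minimal usage sketch (assuming the doctest DataFrame `df` with a `name` column):

```
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

slen = udf(lambda s: len(s), IntegerType())   # a UserDefinedFunction under the hood
df.select(slen(df.name).alias('name_length')).collect()  # triggers __call__ above
```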


2 changes: 1 addition & 1 deletion python/pyspark/sql/tests.py
@@ -282,7 +282,7 @@ def test_apply_schema(self):
             StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
             StructField("list1", ArrayType(ByteType(), False), False),
             StructField("null1", DoubleType(), True)])
-        df = self.sqlCtx.applySchema(rdd, schema)
+        df = self.sqlCtx.createDataFrame(rdd, schema)
         results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
                                     x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
         r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
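
The test now goes through the public `createDataFrame` instead of the deprecated `applySchema`; in user code the two are interchangeable for an RDD plus an explicit schema. A sketch with a toy schema (assuming the `sc`/`sqlCtx` from the first example):

```
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

schema = StructType([StructField("id", IntegerType(), False),
                     StructField("name", StringType(), True)])
rdd = sc.parallelize([(1, "Alice"), (2, "Bob")])
df = sqlCtx.createDataFrame(rdd, schema)   # replaces sqlCtx.applySchema(rdd, schema)
```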
