65 changes: 54 additions & 11 deletions python/pyspark/sql/dataframe.py
@@ -238,6 +238,22 @@ def printSchema(self):
"""
print (self._jdf.schema().treeString())

def explain(self, extended=False):
"""
Prints the plans (logical and physical) to the console for
debugging purposes.

If extended is False, only prints the physical plan.
"""
self._jdf.explain(extended)

def isLocal(self):
"""
Returns True if the `collect` and `take` methods can be run locally
(without any Spark executors).
"""
return self._jdf.isLocal()
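
A quick usage sketch for the new methods (not part of the diff), assuming the `sc`/`df` fixtures set up in this module's doctests:

df.explain()               # prints only the physical plan
df.explain(extended=True)  # prints both the logical and physical plans
if df.isLocal():
    rows = df.collect()    # can run without any Spark executors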

def show(self):
"""
Print the first 20 rows.
@@ -247,14 +263,12 @@ def show(self):
2 Alice
5 Bob
>>> df
age name
2 Alice
5 Bob
DataFrame[age: int, name: string]
"""
print (self)
print self._jdf.showString().encode('utf8', 'ignore')

def __repr__(self):
return self._jdf.showString()
return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
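
A sketch (not part of the patch) of the behaviour this hunk introduces, using the doctest fixture `df`; the repr text is illustrative:

repr(df)    # now just the schema summary, e.g. DataFrame[age: int, name: string]
df.show()   # still prints the first 20 rows via the JVM-side showString()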

def count(self):
"""Return the number of elements in this RDD.
@@ -336,13 +350,40 @@ def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition.

It's a shorthand for df.rdd.mapPartitions()

>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
>>> rdd.mapPartitions(f).sum()
4
"""
return self.rdd.mapPartitions(f, preservesPartitioning)
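
Since the doctest above exercises a plain RDD, here is a hedged sketch (not in the patch) of the same shorthand applied to the `df` fixture, where each partition yields Row objects:

def count_in_partition(rows):
    # rows is an iterator of Row objects for one partition
    yield sum(1 for _ in rows)

df.mapPartitions(count_in_partition).sum()   # total row count, computed per partition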

def foreach(self, f):
"""
Applies a function to all rows of this DataFrame.

It's a shorthand for df.rdd.foreach()

>>> def f(person):
... print person.name
>>> df.foreach(f)
"""
return self.rdd.foreach(f)

def foreachPartition(self, f):
"""
Applies a function to each partition of this DataFrame.

It's a shorthand for df.rdd.foreachPartition()

>>> def f(people):
... for person in people:
... print person.name
>>> df.foreachPartition(f)
"""
return self.rdd.foreachPartition(f)
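
A hedged sketch (not in the patch) of the usual reason to prefer `foreachPartition` over `foreach`: per-partition setup runs once instead of once per row. `open_connection` is a hypothetical helper, not part of PySpark:

def save_partition(rows):
    conn = open_connection()    # hypothetical: one connection per partition
    for row in rows:
        conn.send(row.name)     # Row fields are accessed by attribute
    conn.close()

df.foreachPartition(save_partition)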

def cache(self):
""" Persist with the default storage level (C{MEMORY_ONLY_SER}).
"""
@@ -377,8 +418,13 @@ def repartition(self, numPartitions):
""" Return a new :class:`DataFrame` that has exactly `numPartitions`
partitions.
"""
rdd = self._jdf.repartition(numPartitions, None)
return DataFrame(rdd, self.sql_ctx)
return DataFrame(self._jdf.repartition(numPartitions, None), self.sql_ctx)

def distinct(self):
"""
Return a new :class:`DataFrame` containing the distinct rows in this DataFrame.
"""
return DataFrame(self._jdf.distinct(), self.sql_ctx)
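
A brief usage sketch (not in the patch) for the repartition/distinct wrappers, again assuming the doctest fixture `df`:

reshuffled = df.repartition(10)      # same rows, exactly 10 partitions
reshuffled.rdd.getNumPartitions()    # 10
df.distinct().count()                # number of distinct rows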

def sample(self, withReplacement, fraction, seed=None):
"""
@@ -957,10 +1003,7 @@ def cast(self, dataType):
return Column(jc, self.sql_ctx)

def __repr__(self):
if self._jdf.isComputable():
return self._jdf.samples()
else:
return 'Column<%s>' % self._jdf.toString()
return 'Column<%s>' % self._jdf.toString().encode('utf8')
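
An illustrative sketch (not in the patch) of the new Column repr; the exact text inside the angle brackets comes from the JVM expression's toString() and may vary:

col = df.age
print repr(col)    # e.g. Column<age> -- illustrative output only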

def toPandas(self):
"""
12 changes: 5 additions & 7 deletions python/pyspark/sql/functions.py
@@ -37,7 +37,7 @@ def _create_function(name, doc=""):
""" Create a function for aggregator by name"""
def _(col):
sc = SparkContext._active_spark_context
jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
return Column(jc)
_.__name__ = name
_.__doc__ = doc
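
A hedged usage sketch (not in the patch): functions generated by `_create_function` (e.g. `min`, `max`) now unwrap a Column's `_jc` themselves, so they can be called with either a Column or a plain column name, assuming the corresponding JVM overloads exist:

from pyspark.sql import functions as F

df.select(F.max(df.age)).collect()    # Column argument: its _jc is passed to the JVM
df.select(F.max("age")).collect()     # string argument: passed through unchanged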
@@ -140,6 +140,7 @@ def __call__(self, *cols):
def udf(f, returnType=StringType()):
"""Create a user defined function (UDF)

>>> from pyspark.sql.types import IntegerType
>>> slen = udf(lambda s: len(s), IntegerType())
>>> df.select(slen(df.name).alias('slen')).collect()
[Row(slen=5), Row(slen=3)]
@@ -151,17 +152,14 @@ def _test():
import doctest
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.dataframe
globs = pyspark.sql.dataframe.__dict__.copy()
import pyspark.sql.functions
globs = pyspark.sql.functions.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
pyspark.sql.functions, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count: