[SPARK-41002][CONNECT][PYTHON] Compatible take, head and first API in Python client

### What changes were proposed in this pull request?

1. Add `take(n)` API.
2. Change `head(n)` API to return `Union[Optional[Row], List[Row]]`.
3. Update `first()` to return `Optional[Row]`.
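
A minimal usage sketch of the resulting call patterns (assuming a hypothetical `spark` Connect session; the values shown are illustrative, not part of this patch):

```python
df = spark.range(3)  # rows: Row(id=0), Row(id=1), Row(id=2)

df.take(2)   # [Row(id=0), Row(id=1)]              -> List[Row]
df.head()    # Row(id=0), or None when df is empty -> Optional[Row]
df.head(2)   # [Row(id=0), Row(id=1)]              -> List[Row]
df.first()   # Row(id=0), or None when df is empty -> Optional[Row]
```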

### Why are the changes needed?

Improve API coverage.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

Closes apache#38488 from amaliujia/SPARK-41002.

Authored-by: Rui Wang <rui.wang@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
amaliujia authored and SandishKumarHN committed Dec 12, 2022
1 parent caa4b86 commit cde382e
Showing 2 changed files with 90 additions and 7 deletions.
61 changes: 57 additions & 4 deletions python/pyspark/sql/connect/dataframe.py
@@ -24,6 +24,7 @@
    Tuple,
    Union,
    TYPE_CHECKING,
    overload,
)

import pandas
@@ -211,14 +212,66 @@ def filter(self, condition: Expression) -> "DataFrame":
        plan.Filter(child=self._plan, filter=condition), session=self._session
    )

def first(self) -> Optional["pandas.DataFrame"]:
return self.head(1)
def first(self) -> Optional[Row]:
    """Returns the first row as a :class:`Row`.

    .. versionadded:: 3.4.0

    Returns
    -------
    :class:`Row`
        First row if :class:`DataFrame` is not empty, otherwise ``None``.
    """
    return self.head()

def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
    return GroupingFrame(self, *cols)

def head(self, n: int) -> Optional["pandas.DataFrame"]:
    return self.limit(n).toPandas()
@overload
def head(self) -> Optional[Row]:
    ...

@overload
def head(self, n: int) -> List[Row]:
    ...

def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
"""Returns the first ``n`` rows.
.. versionadded:: 3.4.0
Parameters
----------
n : int, optional
default 1. Number of rows to return.
Returns
-------
If n is greater than 1, return a list of :class:`Row`.
If n is 1, return a single Row.
"""
    if n is None:
        rs = self.head(1)
        return rs[0] if rs else None
    return self.take(n)

def take(self, num: int) -> List[Row]:
    """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.

    .. versionadded:: 3.4.0

    Parameters
    ----------
    num : int
        Number of records to return. Will return this number of records
        or whatever number is available.

    Returns
    -------
    list
        List of rows.
    """
    return self.limit(num).collect()

# TODO: extend `on` to also be type List[ColumnRef].
def join(
36 changes: 33 additions & 3 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -46,6 +46,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
    if have_pandas:
        connect: RemoteSparkSession
        tbl_name: str
        tbl_name_empty: str
        df_text: "DataFrame"

@classmethod
@@ -61,6 +62,7 @@ def setUpClass(cls: Any):
    cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()

    cls.tbl_name = "test_connect_basic_table_1"
    cls.tbl_name_empty = "test_connect_basic_table_empty"

    # Cleanup test data
    cls.spark_connect_clean_up_test_data()
Expand All @@ -80,10 +82,21 @@ def spark_connect_load_test_data(cls: Any):
    # Since we might create multiple Spark sessions, we need to create a global temporary
    # view that is specifically maintained in the "global_temp" schema.
    df.write.saveAsTable(cls.tbl_name)
    empty_table_schema = StructType(
        [
            StructField("firstname", StringType(), True),
            StructField("middlename", StringType(), True),
            StructField("lastname", StringType(), True),
        ]
    )
    emptyRDD = cls.spark.sparkContext.emptyRDD()
    empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
    empty_df.write.saveAsTable(cls.tbl_name_empty)

@classmethod
def spark_connect_clean_up_test_data(cls: Any) -> None:
    cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
    cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))


class SparkConnectTests(SparkConnectSQLTestCase):
@@ -145,10 +158,27 @@ def test_sql(self):
    self.assertEqual(1, len(pdf.index))

def test_head(self):
    # SPARK-41002: test `head` API in Python Client
    df = self.connect.read.table(self.tbl_name)
    self.assertIsNotNone(len(df.head()))
    self.assertIsNotNone(len(df.head(1)))
    self.assertIsNotNone(len(df.head(5)))
    df2 = self.connect.read.table(self.tbl_name_empty)
    self.assertIsNone(df2.head())

def test_first(self):
    # SPARK-41002: test `first` API in Python Client
    df = self.connect.read.table(self.tbl_name)
    self.assertIsNotNone(len(df.first()))
    df2 = self.connect.read.table(self.tbl_name_empty)
    self.assertIsNone(df2.first())

def test_take(self) -> None:
    # SPARK-41002: test `take` API in Python Client
    df = self.connect.read.table(self.tbl_name)
    rows = df.head(10)
    self.assertIsNotNone(rows)
    self.assertEqual(10, len(rows))
    self.assertEqual(5, len(df.take(5)))
    df2 = self.connect.read.table(self.tbl_name_empty)
    self.assertEqual(0, len(df2.take(5)))

def test_range(self):
    self.assertTrue(
