[SPARK-41002][CONNECT][PYTHON] Compatible take, head and first API in Python client #38488

Closed · wants to merge 4 commits
61 changes: 57 additions & 4 deletions python/pyspark/sql/connect/dataframe.py
@@ -24,6 +24,7 @@
Tuple,
Union,
TYPE_CHECKING,
overload,
)

import pandas
@@ -211,14 +212,66 @@ def filter(self, condition: Expression) -> "DataFrame":
plan.Filter(child=self._plan, filter=condition), session=self._session
)

def first(self) -> Optional["pandas.DataFrame"]:
return self.head(1)
def first(self) -> Optional[Row]:
"""Returns the first row as a :class:`Row`.

.. versionadded:: 3.4.0

Returns
-------
:class:`Row`
First row if :class:`DataFrame` is not empty, otherwise ``None``.
"""
return self.head()
Contributor:
Maybe this is copied from PySpark, but isn't it better to signal intent here by passing an explicit 1 as the parameter?

amaliujia (Contributor, Author), Nov 3, 2022:
This is an implementation detail, but updated to self.head(1).

amaliujia (Contributor, Author), Nov 4, 2022:
Ah, actually we cannot: self.head() returns Optional[Row] but self.head(n) returns List[Row], so self.head() is needed to make the mypy check pass.
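
To illustrate the constraint described in this thread, here is a minimal typing sketch (not part of the PR) of how mypy resolves the two overloads: a bare head() call is typed Optional[Row], while head(1) is typed List[Row], so first() can only forward to self.head().

    # Minimal typing sketch (not from the PR): how mypy resolves the overloads.
    from typing import List, Optional, Union, overload

    class Row: ...  # stand-in for pyspark.sql.Row

    class DataFrame:
        @overload
        def head(self) -> Optional[Row]: ...
        @overload
        def head(self, n: int) -> List[Row]: ...
        def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
            ...  # implementation elided in this sketch

        def first(self) -> Optional[Row]:
            # Accepted by mypy: head() with no argument matches the Optional[Row] overload.
            return self.head()
            # return self.head(1) would be rejected: it is typed List[Row], which does
            # not match the declared Optional[Row] return type of first().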


def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
return GroupingFrame(self, *cols)

def head(self, n: int) -> Optional["pandas.DataFrame"]:
return self.limit(n).toPandas()
@overload
def head(self) -> Optional[Row]:
...

@overload
def head(self, n: int) -> List[Row]:
...

def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
"""Returns the first ``n`` rows.

.. versionadded:: 3.4.0

Parameters
----------
n : int, optional
default 1. Number of rows to return.

Returns
-------
If ``n`` is omitted, a single :class:`Row` (or ``None`` if the :class:`DataFrame` is empty).
Otherwise, a list of at most ``n`` :class:`Row`.
"""
if n is None:
rs = self.head(1)
return rs[0] if rs else None
return self.take(n)

def take(self, num: int) -> List[Row]:
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.

.. versionadded:: 3.4.0

Parameters
----------
num : int
Number of records to return. Will return this number of records
or whatever number is available.

Returns
-------
list
List of rows
"""
return self.limit(num).collect()

# TODO: extend `on` to also be type List[ColumnRef].
def join(
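For context, a brief usage sketch of the take/head/first API surface added above (illustrative only; the session name and table are assumptions, not taken from the PR):

    # Illustrative usage on a Spark Connect DataFrame.
    # Assumes an existing session `spark` and a non-empty table; names are hypothetical.
    df = spark.read.table("some_table")

    row = df.first()        # Optional[Row]: the first row, or None if the table is empty
    same_row = df.head()    # Optional[Row]: equivalent to first()
    top_five = df.head(5)   # List[Row]: up to five rows
    also_five = df.take(5)  # List[Row]: equivalent to df.limit(5).collect()
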
33 changes: 30 additions & 3 deletions python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -43,6 +43,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
if have_pandas:
connect: RemoteSparkSession
tbl_name: str
tbl_name_empty: str
df_text: "DataFrame"

@classmethod
@@ -58,6 +59,7 @@ def setUpClass(cls: Any):
cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()

cls.tbl_name = "test_connect_basic_table_1"
cls.tbl_name_empty = "test_connect_basic_table_empty"

# Cleanup test data
cls.spark_connect_clean_up_test_data()
@@ -76,10 +78,21 @@ def spark_connect_load_test_data(cls: Any):
# Since we might create multiple Spark sessions, we need to create global temporary view
# that is specifically maintained in the "global_temp" schema.
df.write.saveAsTable(cls.tbl_name)
empty_table_schema = StructType(
[
StructField("firstname", StringType(), True),
StructField("middlename", StringType(), True),
StructField("lastname", StringType(), True),
]
)
emptyRDD = cls.spark.sparkContext.emptyRDD()
empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
empty_df.write.saveAsTable(cls.tbl_name_empty)

@classmethod
def spark_connect_clean_up_test_data(cls: Any) -> None:
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))


class SparkConnectTests(SparkConnectSQLTestCase):
@@ -138,9 +151,23 @@ def test_limit_offset(self):

def test_head(self):
df = self.connect.read.table(self.tbl_name)
pd = df.head(10)
self.assertIsNotNone(pd)
self.assertEqual(10, len(pd.index))
self.assertIsNotNone(len(df.head()))
self.assertIsNotNone(len(df.head(1)))
self.assertIsNotNone(len(df.head(5)))
df2 = self.connect.read.table(self.tbl_name_empty)
self.assertIsNone(df2.head())

def test_first(self):
df = self.connect.read.table(self.tbl_name)
self.assertIsNotNone(len(df.first()))
df2 = self.connect.read.table(self.tbl_name_empty)
self.assertIsNone(df2.first())

def test_take(self) -> None:
df = self.connect.read.table(self.tbl_name)
self.assertEqual(5, len(df.take(5)))
df2 = self.connect.read.table(self.tbl_name_empty)
self.assertEqual(0, len(df2.take(5)))

def test_range(self):
self.assertTrue(