diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9ddb0db30028..dafde640e395f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -24,6 +24,7 @@
     Tuple,
     Union,
     TYPE_CHECKING,
+    overload,
 )
 
 import pandas
@@ -211,14 +212,66 @@ def filter(self, condition: Expression) -> "DataFrame":
             plan.Filter(child=self._plan, filter=condition), session=self._session
         )
 
-    def first(self) -> Optional["pandas.DataFrame"]:
-        return self.head(1)
+    def first(self) -> Optional[Row]:
+        """Returns the first row as a :class:`Row`.
+
+        .. versionadded:: 3.4.0
+
+        Returns
+        -------
+        :class:`Row`
+            First row if :class:`DataFrame` is not empty, otherwise ``None``.
+        """
+        return self.head()
 
     def groupBy(self, *cols: "ColumnOrString") -> GroupingFrame:
         return GroupingFrame(self, *cols)
 
-    def head(self, n: int) -> Optional["pandas.DataFrame"]:
-        return self.limit(n).toPandas()
+    @overload
+    def head(self) -> Optional[Row]:
+        ...
+
+    @overload
+    def head(self, n: int) -> List[Row]:
+        ...
+
+    def head(self, n: Optional[int] = None) -> Union[Optional[Row], List[Row]]:
+        """Returns the first ``n`` rows.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        n : int, optional
+            default 1. Number of rows to return.
+
+        Returns
+        -------
+        If n is greater than 1, return a list of :class:`Row`.
+        If n is 1, return a single Row.
+        """
+        if n is None:
+            rs = self.head(1)
+            return rs[0] if rs else None
+        return self.take(n)
+
+    def take(self, num: int) -> List[Row]:
+        """Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        num : int
+            Number of records to return. Will return this number of records
+            or whatever number is available.
+
+        Returns
+        -------
+        list
+            List of rows
+        """
+        return self.limit(num).collect()
 
     # TODO: extend `on` to also be type List[ColumnRef].
     def join(
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 0d3fc76134eb7..89d9e0039bf5c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -43,6 +43,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase):
     if have_pandas:
         connect: RemoteSparkSession
         tbl_name: str
+        tbl_name_empty: str
         df_text: "DataFrame"
 
     @classmethod
@@ -58,6 +59,7 @@ def setUpClass(cls: Any):
         cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
 
         cls.tbl_name = "test_connect_basic_table_1"
+        cls.tbl_name_empty = "test_connect_basic_table_empty"
 
         # Cleanup test data
         cls.spark_connect_clean_up_test_data()
@@ -76,10 +78,21 @@ def spark_connect_load_test_data(cls: Any):
         # Since we might create multiple Spark sessions, we need to create global temporary view
         # that is specifically maintained in the "global_temp" schema.
         df.write.saveAsTable(cls.tbl_name)
+        empty_table_schema = StructType(
+            [
+                StructField("firstname", StringType(), True),
+                StructField("middlename", StringType(), True),
+                StructField("lastname", StringType(), True),
+            ]
+        )
+        emptyRDD = cls.spark.sparkContext.emptyRDD()
+        empty_df = cls.spark.createDataFrame(emptyRDD, empty_table_schema)
+        empty_df.write.saveAsTable(cls.tbl_name_empty)
 
     @classmethod
     def spark_connect_clean_up_test_data(cls: Any) -> None:
         cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name))
+        cls.spark.sql("DROP TABLE IF EXISTS {}".format(cls.tbl_name_empty))
 
 
 class SparkConnectTests(SparkConnectSQLTestCase):
@@ -137,10 +150,27 @@ def test_limit_offset(self):
         self.assertEqual(2, len(pd2.index))
 
     def test_head(self):
+        # SPARK-41002: test `head` API in Python Client
+        df = self.connect.read.table(self.tbl_name)
+        self.assertIsNotNone(len(df.head()))
+        self.assertIsNotNone(len(df.head(1)))
+        self.assertIsNotNone(len(df.head(5)))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertIsNone(df2.head())
+
+    def test_first(self):
+        # SPARK-41002: test `first` API in Python Client
+        df = self.connect.read.table(self.tbl_name)
+        self.assertIsNotNone(len(df.first()))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertIsNone(df2.first())
+
+    def test_take(self) -> None:
+        # SPARK-41002: test `take` API in Python Client
         df = self.connect.read.table(self.tbl_name)
-        pd = df.head(10)
-        self.assertIsNotNone(pd)
-        self.assertEqual(10, len(pd.index))
+        self.assertEqual(5, len(df.take(5)))
+        df2 = self.connect.read.table(self.tbl_name_empty)
+        self.assertEqual(0, len(df2.take(5)))
 
     def test_range(self):
         self.assertTrue(
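
For context, a minimal usage sketch of the three APIs this patch adds. It assumes a reachable Spark Connect server, an already-constructed RemoteSparkSession bound to `spark`, and a table named "people"; these names are illustrative only and do not come from the patch itself:

    # Hypothetical session and table; only the method semantics below
    # are taken from the patch.
    df = spark.read.table("people")

    row = df.first()   # single Row, or None if the table is empty
    same = df.head()   # no argument: same behavior as first()
    top5 = df.head(5)  # with n given: a list of up to 5 Rows
    rows = df.take(3)  # a list of up to 3 Rows, i.e. limit(3).collect()

Note that `head` is the only one of the three whose return shape depends on how it is called: `Optional[Row]` with no argument, `List[Row]` otherwise, which is why the patch declares it with `typing.overload`.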
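A self-contained sketch of that `@overload` pattern, using illustrative non-Spark types: the stub definitions exist only for type checkers, and the last definition is the one that actually runs.

    from typing import List, Optional, Union, overload

    class Fetcher:
        # Stubs seen by type checkers; their bodies are intentionally empty.
        @overload
        def head(self) -> Optional[str]:
            ...

        @overload
        def head(self, n: int) -> List[str]:
            ...

        # The single runtime implementation backing both signatures.
        def head(self, n: Optional[int] = None) -> Union[Optional[str], List[str]]:
            data = ["a", "b", "c"]
            if n is None:
                return data[0] if data else None
            return data[:n]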