From b89376d935d2dd192b5d90f2596fe35e368b32cf Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Aug 2023 11:28:01 +0900 Subject: [PATCH 1/2] Make TimestampNTZ works with literals in Python Spark Connect --- .../pandas/tests/connect/test_parity_resample.py | 4 +--- python/pyspark/sql/connect/expressions.py | 3 +++ python/pyspark/sql/functions.py | 12 ++++++++++++ python/pyspark/sql/utils.py | 13 +++++++++++-- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/python/pyspark/pandas/tests/connect/test_parity_resample.py b/python/pyspark/pandas/tests/connect/test_parity_resample.py index d5c901f113a05..caca2f957b507 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_resample.py +++ b/python/pyspark/pandas/tests/connect/test_parity_resample.py @@ -30,9 +30,7 @@ class ResampleParityTests( class ResampleWithTimezoneTests( ResampleWithTimezoneMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase ): - @unittest.skip("SPARK-44731: Support 'spark.sql.timestampType' in Python Spark Connect client") - def test_series_resample_with_timezone(self): - super().test_series_resample_with_timezone() + pass if __name__ == "__main__": diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 44e6e174f70c5..d0a9b1d69aee3 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -15,6 +15,7 @@ # limitations under the License. # from pyspark.sql.connect.utils import check_dependencies +from pyspark.sql.utils import is_timestamp_ntz_preferred check_dependencies(__name__) @@ -295,6 +296,8 @@ def _infer_type(cls, value: Any) -> DataType: return StringType() elif isinstance(value, decimal.Decimal): return DecimalType() + elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred(): + return TimestampNTZType() elif isinstance(value, datetime.datetime): return TimestampType() elif isinstance(value, datetime.date): diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 1b8a946e02e48..fdb4ec8111ed4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -7758,6 +7758,7 @@ def check_field(field: Union[Column, str], fieldName: str) -> None: return _invoke_function("session_window", time_col, gap_duration) +@try_remote_functions def to_unix_timestamp( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, @@ -7767,6 +7768,9 @@ def to_unix_timestamp( .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- timestamp : :class:`~pyspark.sql.Column` or str @@ -7794,6 +7798,7 @@ def to_unix_timestamp( return _invoke_function_over_columns("to_unix_timestamp", timestamp) +@try_remote_functions def to_timestamp_ltz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, @@ -7804,6 +7809,9 @@ def to_timestamp_ltz( .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- timestamp : :class:`~pyspark.sql.Column` or str @@ -7831,6 +7839,7 @@ def to_timestamp_ltz( return _invoke_function_over_columns("to_timestamp_ltz", timestamp) +@try_remote_functions def to_timestamp_ntz( timestamp: "ColumnOrName", format: Optional["ColumnOrName"] = None, @@ -7841,6 +7850,9 @@ def to_timestamp_ntz( .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- timestamp : :class:`~pyspark.sql.Column` or str diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index d4f56fe822f3e..17391f0187b01 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -140,8 +140,17 @@ def is_timestamp_ntz_preferred() -> bool: """ Return a bool if TimestampNTZType is preferred according to the SQL configuration set. """ - jvm = SparkContext._jvm - return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred() + if is_remote(): + from pyspark.sql.connect.session import SparkSession + + session = SparkSession.getActiveSession() + if session is None: + return False + else: + return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ" + else: + jvm = SparkContext._jvm + return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred() def is_remote() -> bool: From eb7598f91b5bcf9b83f8623b57486c044760d440 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Aug 2023 11:37:31 +0900 Subject: [PATCH 2/2] Rename for mypy --- python/pyspark/sql/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 17391f0187b01..cb262a14cbe2c 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -141,9 +141,9 @@ def is_timestamp_ntz_preferred() -> bool: Return a bool if TimestampNTZType is preferred according to the SQL configuration set. """ if is_remote(): - from pyspark.sql.connect.session import SparkSession + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession - session = SparkSession.getActiveSession() + session = ConnectSparkSession.getActiveSession() if session is None: return False else: