From b6cb5f87a34114a0fce3372811707db03dd48813 Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 1 Oct 2025 10:34:47 -0700 Subject: [PATCH 1/3] init --- python/pyspark/sql/connect/session.py | 53 +++++++++++++++++++ .../sql/tests/connect/test_connect_session.py | 26 +++++++++ 2 files changed, 79 insertions(+) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 6ccffc718d064..e623f94099997 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -26,6 +26,7 @@ from collections.abc import Callable, Sized import functools from threading import RLock +from types import TracebackType from typing import ( Optional, Any, @@ -40,6 +41,7 @@ Mapping, TYPE_CHECKING, ClassVar, + Type, ) import numpy as np @@ -947,6 +949,57 @@ def stop(self) -> None: if "SPARK_REMOTE" in os.environ: del os.environ["SPARK_REMOTE"] + def __enter__(self) -> "SparkSession": + """ + Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax. + + .. versionadded:: 2.0.0 + + Examples + -------- + >>> with SparkSession.builder.master("local").getOrCreate() as session: + ... session.range(5).show() # doctest: +SKIP + +---+ + | id| + +---+ + | 0| + | 1| + | 2| + | 3| + | 4| + +---+ + """ + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """ + Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax. + + Specifically stop the SparkSession on exit of the with block. + + .. versionadded:: 2.0.0 + + Examples + -------- + >>> with SparkSession.builder.master("local").getOrCreate() as session: + ... session.range(5).show() # doctest: +SKIP + +---+ + | id| + +---+ + | 0| + | 1| + | 2| + | 3| + | 4| + +---+ + """ + self.stop() + @property def is_stopped(self) -> bool: """ diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py index 1857796ac9aa0..21e4cb831c86d 100644 --- a/python/pyspark/sql/tests/connect/test_connect_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_session.py @@ -324,6 +324,32 @@ def test_config(self): self.assertEqual(self.spark.conf.get("boolean"), "false") self.assertEqual(self.spark.conf.get("integer"), "1") + def test_context_manager_enter_exit(self): + """Test that SparkSession works as a context manager.""" + # Create a new session for testing + with PySparkSession.builder.remote("local[2]").getOrCreate() as session: + self.assertIsInstance(session, PySparkSession) + self.assertFalse(session.is_stopped) + + df = session.range(3) + result = df.collect() + self.assertEqual(len(result), 3) + + self.assertTrue(session.is_stopped) + + def test_context_manager_with_exception(self): + """Test that SparkSession is properly stopped even when exception occurs.""" + session = None + try: + with PySparkSession.builder.remote("local[2]").getOrCreate() as session: + self.assertIsInstance(session, PySparkSession) + self.assertFalse(session.is_stopped) + raise ValueError("Test exception") + except ValueError: + pass # Expected exception + + self.assertTrue(session.is_stopped) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_session import * # noqa: F401 From 321ceb5113637c973ff8ac6899fa5288e1a2d9ad Mon Sep 17 00:00:00 2001 From: Shujing Yang Date: Wed, 1 Oct 2025 16:06:52 -0700 Subject: [PATCH 2/3] fix --- python/pyspark/sql/tests/connect/test_connect_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py index 21e4cb831c86d..98b1dfc43dd01 100644 --- a/python/pyspark/sql/tests/connect/test_connect_session.py +++ b/python/pyspark/sql/tests/connect/test_connect_session.py @@ -328,7 +328,7 @@ def test_context_manager_enter_exit(self): """Test that SparkSession works as a context manager.""" # Create a new session for testing with PySparkSession.builder.remote("local[2]").getOrCreate() as session: - self.assertIsInstance(session, PySparkSession) + self.assertIsInstance(session, RemoteSparkSession) self.assertFalse(session.is_stopped) df = session.range(3) @@ -342,7 +342,7 @@ def test_context_manager_with_exception(self): session = None try: with PySparkSession.builder.remote("local[2]").getOrCreate() as session: - self.assertIsInstance(session, PySparkSession) + self.assertIsInstance(session, RemoteSparkSession) self.assertFalse(session.is_stopped) raise ValueError("Test exception") except ValueError: From 86e558c0ebd4acbba50f4ebbdf623e2a00764fea Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 2 Oct 2025 08:50:02 +0900 Subject: [PATCH 3/3] Apply suggestions from code review --- python/pyspark/sql/connect/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index e623f94099997..64f095a5b018b 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -953,7 +953,7 @@ def __enter__(self) -> "SparkSession": """ Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax. - .. versionadded:: 2.0.0 + .. versionadded:: 4.1.0 Examples -------- @@ -982,7 +982,7 @@ def __exit__( Specifically stop the SparkSession on exit of the with block. - .. versionadded:: 2.0.0 + .. versionadded:: 4.1.0 Examples --------