[SPARK-37291][PYSPARK][FOLLOWUP] PySpark create SparkSession should pass initialSessionOptions

### What changes were proposed in this pull request?
In this PR, when creating a SparkSession, we pass initialSessionOptions to the JVM SparkSession constructor, to keep the same code path as the Scala code.
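Below is a minimal sketch of the behavior this preserves, based on the updated unit test (assuming a local `pyspark` build; the app name is illustrative):

```python
from pyspark import SparkContext
from pyspark.sql import SparkSession

# Create the SparkContext first, so SparkSession.__init__ must construct a new
# JVM-side session instead of reusing an existing default session.
sc = SparkContext("local[4]", "demo")

# Options set on the builder are collected as initialSessionOptions and now
# flow through the JVM SparkSession constructor, as on the Scala side.
spark = SparkSession.builder.config("spark.sql.codegen.comments", "true").getOrCreate()

# The option lands in the shared state conf, which is what the test asserts.
print(spark._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments"))
spark.stop()
```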

### Why are the changes needed?
To keep the same code path as the Scala code.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing unit tests.
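
A hedged sketch of running the renamed test directly with `unittest` (the class name `SparkSessionBuilderTests` is an assumption based on the surrounding tests in `test_session.py`; Spark's own `python/run-tests` script is the usual entry point):

```python
import unittest

# Assumed test location: pyspark.sql.tests.test_session.SparkSessionBuilderTests
suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.sql.tests.test_session.SparkSessionBuilderTests"
    ".test_create_spark_context_with_initial_session_options"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```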

Closes #34732 from AngersZhuuuu/SPARK-37291-FOLLOWUP.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
AngersZhuuuu authored and HyukjinKwon committed Nov 30, 2021
1 parent e031d00 commit 1a43112
Showing 3 changed files with 16 additions and 10 deletions.
python/pyspark/sql/session.py (2 additions, 5 deletions)
```diff
@@ -291,7 +291,7 @@ def __init__(
         self,
         sparkContext: SparkContext,
         jsparkSession: Optional[JavaObject] = None,
-        options: Optional[Dict[str, Any]] = None,
+        options: Optional[Dict[str, Any]] = {},
     ):
         from pyspark.sql.context import SQLContext
 
@@ -305,10 +305,7 @@ def __init__(
         ):
             jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
         else:
-            jsparkSession = self._jvm.SparkSession(self._jsc.sc())
-            if options is not None:
-                for key, value in options.items():
-                    jsparkSession.sharedState().conf().set(key, value)
+            jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options)
         self._jsparkSession = jsparkSession
         self._jwrapped = self._jsparkSession.sqlContext()
         self._wrapped = SQLContext(self._sc, self, self._jwrapped)
```
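Passing the Python `dict` straight through works because PySpark's Py4J gateway auto-converts a `dict` into a `java.util.HashMap`, the exact parameter type of the new Scala constructor (see the SparkSession.scala diff below). A self-contained sketch of that call path, under the same assumptions as above:

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "ctor-demo")
# Py4J converts the dict to java.util.HashMap[String, String], which the
# two-argument JVM constructor accepts directly.
jspark = sc._jvm.SparkSession(sc._jsc.sc(), {"spark.sql.codegen.comments": "true"})
print(jspark.sharedState().conf().get("spark.sql.codegen.comments"))
sc.stop()
```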
python/pyspark/sql/tests/test_session.py (8 additions, 3 deletions)

```diff
@@ -289,18 +289,23 @@ def test_another_spark_session(self):
             if session2 is not None:
                 session2.stop()
 
-    def test_create_spark_context_first_and_copy_options_to_sharedState(self):
+    def test_create_spark_context_with_initial_session_options(self):
         sc = None
         session = None
         try:
             conf = SparkConf().set("key1", "value1")
             sc = SparkContext("local[4]", "SessionBuilderTests", conf=conf)
             session = (
-                SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate()
+                SparkSession.builder.config("spark.sql.codegen.comments", "true")
+                .enableHiveSupport()
+                .getOrCreate()
             )
 
             self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1")
-            self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2")
+            self.assertEqual(
+                session._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments"),
+                "true",
+            )
             self.assertEqual(
                 session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"),
                 "hive",
```
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala (6 additions, 2 deletions)

```diff
@@ -97,13 +97,17 @@ class SparkSession private(
    * since that would cause every new session to reinvoke Spark Session Extensions on the currently
    * running extensions.
    */
-  private[sql] def this(sc: SparkContext) = {
+  private[sql] def this(
+      sc: SparkContext,
+      initialSessionOptions: java.util.HashMap[String, String]) = {
     this(sc, None, None,
       SparkSession.applyExtensions(
         sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
-        new SparkSessionExtensions), Map.empty)
+        new SparkSessionExtensions), initialSessionOptions.asScala.toMap)
   }
 
+  private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]())
+
   private[sql] val sessionUUID: String = UUID.randomUUID.toString
 
   sparkContext.assertNotStopped()
```
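The retained one-argument overload keeps existing callers working by delegating with an empty map; as a sketch (same assumptions as the snippet after the session.py diff), these two JVM-side constructions take the same path:

```python
from pyspark import SparkContext

sc = SparkContext("local[2]", "overload-demo")
# new SparkSession(sc) now delegates to new SparkSession(sc, new HashMap()),
# so the no-options call and the empty-dict call are equivalent.
jspark_a = sc._jvm.SparkSession(sc._jsc.sc())
jspark_b = sc._jvm.SparkSession(sc._jsc.sc(), {})
sc.stop()
```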
