diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 848a0215240bb..cb58ad3ff350c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -980,7 +980,6 @@ abstract class SQLContext private[sql] (val sparkSession: SparkSession)
  */
 private[sql] trait SQLContextCompanion {
   private[sql] type SQLContextImpl <: SQLContext
-  private[sql] type SparkContextImpl <: SparkContext
 
   /**
    * Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
@@ -994,7 +993,7 @@ private[sql] trait SQLContextCompanion {
    * @since 1.5.0
    */
   @deprecated("Use SparkSession.builder instead", "2.0.0")
-  def getOrCreate(sparkContext: SparkContextImpl): SQLContextImpl
+  def getOrCreate(sparkContext: SparkContext): SQLContextImpl
 
   /**
    * Changes the SQLContext that will be returned in this thread and its children when
@@ -1019,3 +1018,13 @@ private[sql] trait SQLContextCompanion {
     SparkSession.clearActiveSession()
   }
 }
+
+object SQLContext extends SQLContextCompanion {
+  private[sql] type SQLContextImpl = SQLContext
+
+  /** @inheritdoc */
+  @deprecated("Use SparkSession.builder instead", "2.0.0")
+  def getOrCreate(sparkContext: SparkContext): SQLContext = {
+    SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
+  }
+}
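With `SparkContextImpl` gone, the companion trait takes the concrete `org.apache.spark.SparkContext` everywhere, and the new shared `object SQLContext` in sql/api gives the deprecated entry point a single implementation that routes through `SparkSession.builder()`. A minimal sketch of what a legacy caller exercises after this change (the app name and master are illustrative, not from the patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object LegacyEntryPoint {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("legacy-sqlcontext").setMaster("local[2]"))
    // Deprecated since 2.0.0 but still supported: the shared companion now
    // builds (or reuses) a SparkSession around the given SparkContext and
    // hands back its sqlContext.
    val ctx = SQLContext.getOrCreate(sc)
    assert(ctx.range(0, 5).count() == 5)
    sc.stop()
  }
}
```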
diff --git a/sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala
index 34fb507c65686..57eddd1bc69fa 100644
--- a/sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala
+++ b/sql/api/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala
@@ -17,22 +17,57 @@
 package org.apache.spark.sql
 
 // scalastyle:off funsuite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 import org.scalatest.funsuite.AnyFunSuite
 
-import org.apache.spark.sql.functions.sum
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.functions.{max, sum}
 
 /**
  * Test suite for SparkSession implementation binding.
  */
-trait SparkSessionBuilderImplementationBindingSuite extends AnyFunSuite with BeforeAndAfterAll {
+trait SparkSessionBuilderImplementationBindingSuite
+    extends AnyFunSuite
+    with BeforeAndAfterAll
+    with BeforeAndAfterEach {
   // scalastyle:on
-  protected def configure(builder: SparkSessionBuilder): builder.type = builder
+
+  protected def sparkContext: SparkContext
+  protected def implementationPackageName: String = getClass.getPackageName
+
+  private def assertInCorrectPackage[T](obj: T): Unit = {
+    assert(obj.getClass.getPackageName == implementationPackageName)
+  }
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    clearSessions()
+  }
+
+  override protected def afterAll(): Unit = {
+    clearSessions()
+    super.afterAll()
+  }
+
+  private def clearSessions(): Unit = {
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+  }
 
   test("range") {
-    val session = configure(SparkSession.builder()).getOrCreate()
+    val session = SparkSession.builder().getOrCreate()
+    assertInCorrectPackage(session)
     import session.implicits._
     val df = session.range(10).agg(sum("id")).as[Long]
     assert(df.head() == 45)
   }
+
+  test("sqlContext") {
+    SparkSession.clearActiveSession()
+    val ctx = SQLContext.getOrCreate(sparkContext)
+    assertInCorrectPackage(ctx)
+    import ctx.implicits._
+    val df = ctx.createDataset(1 to 11).select(max("value").as[Long])
+    assert(df.head() == 11)
+  }
 }
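Concrete suites now bind the trait by supplying a `SparkContext` instead of mutating the builder; `implementationPackageName` defaults to the suite's own package so `assertInCorrectPackage` can verify that `SparkSession.builder()` and `SQLContext.getOrCreate` return the classic or Connect implementation the suite sits next to. A hypothetical binding, purely for illustration (suite name, app name, and master are made up; real suites obtain their `SparkContext` from the module's test harness):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSessionBuilderImplementationBindingSuite

class ExampleBindingSuite extends SparkSessionBuilderImplementationBindingSuite {
  // Inline context only for the sketch; shared test bases normally own this.
  private lazy val sc =
    new SparkContext(new SparkConf().setAppName("binding").setMaster("local[2]"))

  override protected def sparkContext: SparkContext = sc

  // Only needed when the suite does not live in the package of the
  // implementation it exercises.
  override protected def implementationPackageName: String = "org.apache.spark.sql.classic"
}
```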
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala
index cc6bc8af1f4b3..06eb06299f4c4 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionBuilderImplementationBindingSuite.scala
@@ -16,8 +16,7 @@
  */
 package org.apache.spark.sql.connect
 
-import org.apache.spark.sql
-import org.apache.spark.sql.SparkSessionBuilder
+import org.apache.spark.{sql, SparkContext}
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
 
 /**
@@ -27,8 +26,11 @@ class SparkSessionBuilderImplementationBindingSuite
   extends ConnectFunSuite
   with sql.SparkSessionBuilderImplementationBindingSuite
   with RemoteSparkSession {
-  override protected def configure(builder: SparkSessionBuilder): builder.type = {
+  override def beforeAll(): Unit = {
     // We need to set this configuration because the port used by the server is random.
-    builder.remote(s"sc://localhost:$serverPort")
+    System.setProperty("spark.remote", s"sc://localhost:$serverPort")
+    super.beforeAll()
   }
+
+  override protected def sparkContext: SparkContext = null
 }
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala
index e38179e232d05..cc34ca6c9ffd7 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SQLContext.scala
@@ -305,29 +305,13 @@ class SQLContext private[sql] (override val sparkSession: SparkSession)
       super.jdbc(url, table, theParts)
     }
   }
+
 object SQLContext extends sql.SQLContextCompanion {
 
   override private[sql] type SQLContextImpl = SQLContext
-  override private[sql] type SparkContextImpl = SparkContext
 
-  /**
-   * Get the singleton SQLContext if it exists or create a new one.
-   *
-   * This function can be used to create a singleton SQLContext object that can be shared across
-   * the JVM.
-   *
-   * If there is an active SQLContext for current thread, it will be returned instead of the
-   * global one.
-   *
-   * @param sparkContext
-   *   The SparkContext. This parameter is not used in Spark Connect.
-   *
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   def getOrCreate(sparkContext: SparkContext): SQLContext = {
     SparkSession.builder().getOrCreate().sqlContext
   }
-
-  /** @inheritdoc */
-  override def setActive(sqlContext: SQLContext): Unit = super.setActive(sqlContext)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala
index 2d5d26fe6016e..18a84d8c4299a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SQLContext.scala
@@ -378,18 +378,13 @@ class SQLContext private[sql] (override val sparkSession: SparkSession)
 }
 
 object SQLContext extends sql.SQLContextCompanion {
-
   override private[sql] type SQLContextImpl = SQLContext
-  override private[sql] type SparkContextImpl = SparkContext
 
   /** @inheritdoc */
   def getOrCreate(sparkContext: SparkContext): SQLContext = {
     newSparkSessionBuilder().sparkContext(sparkContext).getOrCreate().sqlContext
  }
 
-  /** @inheritdoc */
-  override def setActive(sqlContext: SQLContext): Unit = super.setActive(sqlContext)
-
   /**
    * Converts an iterator of Java Beans to InternalRow using the provided bean info & schema. This
    * is not related to the singleton, but is a static method for internal use.
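Both implementations now inherit the `getOrCreate` contract (and its scaladoc) from the shared companion: classic threads the given `SparkContext` into the session builder, while Connect ignores the argument entirely, which is why the Connect test binding above can safely return `null`. A sketch of the Connect-side call under that contract, assuming a server is already reachable through `spark.remote`/`SPARK_REMOTE` (nothing here is part of the patch itself):

```scala
import org.apache.spark.sql.connect.SQLContext

object ConnectLegacyEntryPoint {
  def main(args: Array[String]): Unit = {
    // The SparkContext parameter is documented as unused in Spark Connect,
    // so there is nothing to construct on the client side.
    val ctx = SQLContext.getOrCreate(null)
    val spark = ctx.sparkSession
    assert(spark.range(3).count() == 3)
    spark.stop()
  }
}
```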