-
Notifications
You must be signed in to change notification settings - Fork 28.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-43474] [SS] [CONNECT] Add a spark connect access to runtime Dataframes by ID. #41580
Changes from 7 commits
e50fa97
e0c711d
647ca67
1ca1f3a
5d25a32
8a05f9b
dde3e36
ef62238
d3c2503
74b98ca
b6e8f3d
bb2ad33
5af56c9
5c0c2de
6dca3e6
8de5197
dbca798
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.connect.service | ||
|
||
import javax.annotation.concurrent.GuardedBy | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.sql.DataFrame | ||
import org.apache.spark.sql.SparkSession | ||
import org.apache.spark.sql.connect.common.InvalidPlanInput | ||
|
||
/**
 * Server-side cache of DataFrames keyed by client-supplied ids. A Spark Connect client creates a
 * DataFrame reference carrying such an id; when the server transforms that reference, it looks
 * the DataFrame up here and substitutes it for the reference.
 *
 * Entries are partitioned per session: a cached DataFrame is only visible to the session that
 * stored it. Entries are dropped either by their owner (e.g. a streaming query) or when the
 * session expires.
 */
private[sql] class SparkConnectCachedDataFrameManager {

  // Two-level map: Session.sessionUUID -> (DataFrame reference id -> DataFrame).
  @GuardedBy("this")
  private val cacheBySession = mutable.Map.empty[String, mutable.Map[String, DataFrame]]

  /** Registers `value` under `dataFrameId` for the given session, creating the per-session map on first use. */
  def put(session: SparkSession, dataFrameId: String, value: DataFrame): Unit = synchronized {
    val sessionMap =
      cacheBySession.getOrElseUpdate(session.sessionUUID, mutable.Map.empty[String, DataFrame])
    sessionMap.update(dataFrameId, value)
  }

  /**
   * Returns the DataFrame cached under `dataFrameId` for the given session.
   * Throws [[InvalidPlanInput]] when no such entry exists (unknown session or unknown id).
   */
  def get(session: SparkSession, dataFrameId: String): DataFrame = synchronized {
    val sessionKey = session.sessionUUID
    cacheBySession.get(sessionKey).flatMap(_.get(dataFrameId)) match {
      case Some(df) => df
      case None =>
        throw InvalidPlanInput(
          s"No DataFrame found in the server for id $dataFrameId in the session $sessionKey")
    }
  }

  /** Drops every cached DataFrame belonging to the given session. */
  def remove(session: SparkSession): Unit = synchronized {
    cacheBySession.remove(session.sessionUUID)
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.connect.service | ||
|
||
import org.apache.spark.sql.connect.common.InvalidPlanInput | ||
import org.apache.spark.sql.test.SharedSparkSession | ||
|
||
class SparkConnectCachedDataFrameManagerSuite extends SharedSparkSession {

  test("Successful put and get") {
    val spark = this.spark
    import spark.implicits._

    val manager = new SparkConnectCachedDataFrameManager()

    // Cache one DataFrame and read it back: the exact same instance must be returned.
    val firstDf = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3")).toDF()
    manager.put(spark, "key_1", firstDf)
    assert(manager.get(spark, "key_1") == firstDf)

    // A second id in the same session is cached independently.
    val secondDf = Seq(("k4", "v4"), ("k5", "v5")).toDF()
    manager.put(spark, "key_2", secondDf)
    assert(manager.get(spark, "key_2") == secondDf)
  }

  test("Get cache that does not exist should fail") {
    val spark = this.spark
    import spark.implicits._

    val manager = new SparkConnectCachedDataFrameManager()

    // Looking up an id before anything is cached must be rejected.
    assertThrows[InvalidPlanInput] {
      manager.get(spark, "key_1")
    }

    // Caching one id does not make other ids resolvable.
    val cachedDf = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3")).toDF()
    manager.put(spark, "key_1", cachedDf)
    manager.get(spark, "key_1")

    assertThrows[InvalidPlanInput] {
      manager.get(spark, "key_2")
    }
  }

  test("Remove cache and then get should fail") {
    val spark = this.spark
    import spark.implicits._

    val manager = new SparkConnectCachedDataFrameManager()

    // Cache an entry and confirm it is resolvable.
    val cachedDf = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3")).toDF()
    manager.put(spark, "key_1", cachedDf)
    manager.get(spark, "key_1")

    // After the session's entries are removed, the id must no longer resolve.
    manager.remove(spark)
    assertThrows[InvalidPlanInput] {
      manager.get(spark, "key_1")
    }
  }
}
Large diffs are not rendered by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw the session is already a class member of SparkConnectPlanner.