[SPARK-43474] [SS] [CONNECT] Add a spark connect access to runtime Dataframes by ID. #41580

Closed · wants to merge 17 commits
@@ -70,6 +70,7 @@ message Relation {
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
CachedLocalRelation cached_local_relation = 36;
CachedRemoteRelation cached_remote_relation = 37;

// NA functions
NAFill fill_na = 90;
@@ -395,6 +396,12 @@ message CachedLocalRelation {
string hash = 3;
}

// Represents a remote relation that has been cached on the server.
message CachedRemoteRelation {
// (Required) ID of the remote relation (assigned by the service).
string relation_id = 1;
}
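
As a rough illustration of how a client could use this message (a hedged sketch, assuming the standard Java builder API that protobuf generates for these messages; the id "df-123" is made up):

import org.apache.spark.connect.proto

// Build a Relation that references a DataFrame the server has already
// cached under a given id ("df-123" is a made-up example id).
val relation = proto.Relation
  .newBuilder()
  .setCachedRemoteRelation(
    proto.CachedRemoteRelation.newBuilder().setRelationId("df-123").build())
  .build()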

// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
// (Required) Input relation for a Sample.
@@ -145,6 +145,8 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {
transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
transformApplyInPandasWithState(rel.getApplyInPandasWithState)
case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
transformCachedRemoteRelation(session, rel.getCachedRemoteRelation)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -788,6 +790,14 @@ class SparkConnectPlanner(val session: SparkSession) extends Logging {
.logicalPlan
}

private def transformCachedRemoteRelation(
session: SparkSession,
Review comment (Contributor): btw the session is already a class member of SparkConnectPlanner.

rel: proto.CachedRemoteRelation): LogicalPlan = {
SparkConnectService.cachedDataFrameManager
.get(session, rel.getRelationId)
.logicalPlan
}
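
A minimal sketch of the simplification the reviewer suggests above (hypothetical, not part of this diff): since `session` is already a class member of SparkConnectPlanner, the explicit parameter can be dropped.

private def transformCachedRemoteRelation(rel: proto.CachedRemoteRelation): LogicalPlan = {
  // Uses the planner's own `session` member rather than a passed-in session.
  SparkConnectService.cachedDataFrameManager
    .get(session, rel.getRelationId)
    .logicalPlan
}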

private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.common.InvalidPlanInput

/**
 * This class caches DataFrames on the server side under given IDs. The Spark Connect client can
 * create a DataFrame reference with such an ID. When the server transforms the DataFrame
 * reference, it finds the DataFrame in the cache and replaces the reference.
 *
 * Each session has a corresponding DataFrame map. A cached DataFrame can only be accessed from
 * within the same session. The DataFrame is removed from the cache by its owner (e.g. a
 * streaming query) or when the session expires.
 */
private[sql] class SparkConnectCachedDataFrameManager {

// Session.sessionUUID -> Map[DF Reference ID -> DF]
@GuardedBy("this")
private val dataFrameCache = mutable.Map[String, mutable.Map[String, DataFrame]]()
Review comment (Contributor): Once you move this into the session holder, please use a concurrent hashmap instead because you only need one then.

def put(session: SparkSession, dataFrameId: String, value: DataFrame): Unit = synchronized {
dataFrameCache
.getOrElseUpdate(session.sessionUUID, mutable.Map[String, DataFrame]())
.put(dataFrameId, value)
}

def get(session: SparkSession, dataFrameId: String): DataFrame = synchronized {
val sessionKey = session.sessionUUID
dataFrameCache
.get(sessionKey)
.flatMap(_.get(dataFrameId))
.getOrElse {
throw InvalidPlanInput(
s"No DataFrame found in the server for id $dataFrameId in the session $sessionKey")
}
}

def remove(session: SparkSession): Unit = synchronized {
dataFrameCache.remove(session.sessionUUID)
}
}
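
To make the reviewer's ConcurrentHashMap suggestion concrete, here is a hedged sketch (the class name SessionLocalDataFrameCache and its placement in a per-session holder are assumptions, not code from this PR): once each session holder owns its own cache, the outer sessionUUID level disappears and a single concurrent map needs no explicit locking.

import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.DataFrame

// Hypothetical per-session cache: one concurrent map, no synchronized blocks.
private[sql] class SessionLocalDataFrameCache {
  private val dataFrameCache = new ConcurrentHashMap[String, DataFrame]()

  def put(dataFrameId: String, value: DataFrame): Unit =
    dataFrameCache.put(dataFrameId, value)

  def get(dataFrameId: String): Option[DataFrame] =
    Option(dataFrameCache.get(dataFrameId))

  // Called when the owning session is closed or expires.
  def clear(): Unit = dataFrameCache.clear()
}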
@@ -298,12 +298,15 @@ object SparkConnectService {
userSessionMapping.getIfPresent((userId, sessionId))
})

private[connect] val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] {
override def onRemoval(
notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = {
val SessionHolder(userId, sessionId, session) = notification.getValue
val blockManager = session.sparkContext.env.blockManager
blockManager.removeCache(userId, sessionId)
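// Also evict every DataFrame cached for this session, so that a later
// session cannot reach it through a guessed relation id.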
cachedDataFrameManager.remove(session)
}
}

@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.service

import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.test.SharedSparkSession

class SparkConnectCachedDataFrameManagerSuite extends SharedSparkSession {
Review comment (Contributor): I think it makes sense to add tests that when you have different sessions with different users, you don't get access to a "guessed" cached plan. This is the key part for security.


test("Successful put and get") {
val spark = this.spark
import spark.implicits._

val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

val key1 = "key_1"
val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
cachedDataFrameManager.put(spark, key1, df1)

val expectedDf1 = cachedDataFrameManager.get(spark, key1)
assert(expectedDf1 == df1)

val key2 = "key_2"
val data2 = Seq(("k4", "v4"), ("k5", "v5"))
val df2 = data2.toDF()
cachedDataFrameManager.put(spark, key2, df2)

val expectedDf2 = cachedDataFrameManager.get(spark, key2)
assert(expectedDf2 == df2)
}

test("Get cache that does not exist should fail") {
val spark = this.spark
import spark.implicits._

val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

val key1 = "key_1"

assertThrows[InvalidPlanInput] {
cachedDataFrameManager.get(spark, key1)
}

val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
cachedDataFrameManager.put(spark, key1, df1)
cachedDataFrameManager.get(spark, key1)

val key2 = "key_2"
assertThrows[InvalidPlanInput] {
cachedDataFrameManager.get(spark, key2)
}
}

test("Remove cache and then get should fail") {
val spark = this.spark
import spark.implicits._

val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

val key1 = "key_1"
val data1 = Seq(("k1", "v1"), ("k2", "v2"), ("k3", "v3"))
val df1 = data1.toDF()
cachedDataFrameManager.put(spark, key1, df1)
cachedDataFrameManager.get(spark, key1)

cachedDataFrameManager.remove(spark)
assertThrows[InvalidPlanInput] {
cachedDataFrameManager.get(spark, key1)
}
}
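
  // A hedged sketch of the cross-session test the reviewer asks for above
  // (not part of this diff as shown). spark.newSession() returns a session
  // with a different sessionUUID, standing in for another user's session.
  test("Get cache from a different session should fail") {
    val spark = this.spark
    import spark.implicits._

    val cachedDataFrameManager = new SparkConnectCachedDataFrameManager()

    val key1 = "key_1"
    cachedDataFrameManager.put(spark, key1, Seq(("k1", "v1")).toDF())

    val otherSession = spark.newSession()
    assertThrows[InvalidPlanInput] {
      cachedDataFrameManager.get(otherSession, key1)
    }
  }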
}
14 changes: 14 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -551,6 +551,20 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan


class CachedRemoteRelation(LogicalPlan):
    """Logical plan object for a DataFrame reference, which represents a DataFrame
    cached on the server under a given id."""

    def __init__(self, relationId: str):
        super().__init__(None)
        self._relationId = relationId

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        plan = self._create_proto_relation()
        # The generated protobuf field is snake_case: relation_id.
        plan.cached_remote_relation.relation_id = self._relationId
        return plan


class Hint(LogicalPlan):
"""Logical plan object for a Hint operation."""

270 changes: 142 additions & 128 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -98,6 +98,7 @@ class Relation(google.protobuf.message.Message):
APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
HTML_STRING_FIELD_NUMBER: builtins.int
CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
CACHED_REMOTE_RELATION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -185,6 +186,8 @@ class Relation(google.protobuf.message.Message):
@property
def cached_local_relation(self) -> global___CachedLocalRelation: ...
@property
def cached_remote_relation(self) -> global___CachedRemoteRelation: ...
@property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -257,6 +260,7 @@ class Relation(google.protobuf.message.Message):
apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
html_string: global___HtmlString | None = ...,
cached_local_relation: global___CachedLocalRelation | None = ...,
cached_remote_relation: global___CachedRemoteRelation | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -283,6 +287,8 @@
b"approx_quantile",
"cached_local_relation",
b"cached_local_relation",
"cached_remote_relation",
b"cached_remote_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -390,6 +396,8 @@
b"approx_quantile",
"cached_local_relation",
b"cached_local_relation",
"cached_remote_relation",
b"cached_remote_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -524,6 +532,7 @@
"apply_in_pandas_with_state",
"html_string",
"cached_local_relation",
"cached_remote_relation",
"fill_na",
"drop_na",
"replace",
@@ -1593,6 +1602,25 @@ class CachedLocalRelation(google.protobuf.message.Message):

global___CachedLocalRelation = CachedLocalRelation

class CachedRemoteRelation(google.protobuf.message.Message):
    """Represents a remote relation that has been cached on the server."""

    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    RELATION_ID_FIELD_NUMBER: builtins.int
    relation_id: builtins.str
    """(Required) ID of the remote relation (assigned by the service)."""
    def __init__(
        self,
        *,
        relation_id: builtins.str = ...,
    ) -> None: ...
    def ClearField(
        self, field_name: typing_extensions.Literal["relation_id", b"relation_id"]
    ) -> None: ...

global___CachedRemoteRelation = CachedRemoteRelation

class Sample(google.protobuf.message.Message):
"""Relation of type [[Sample]] that samples a fraction of the dataset."""
