Skip to content

Commit d9c5f9d

Browse files
grundprinzip authored and HyukjinKwon committed
[SPARK-45798][CONNECT] Assert server-side session ID
### What changes were proposed in this pull request? Without this patch, when the server would restart because of an abnormal condition, the client would not realize that this is the case. For example, when a driver OOM occurs and the driver is restarted, the client would not realize that the server is restarted and a new session is assigned. This patch fixes this behavior and asserts that the server side session ID does not change during the connection. If it does change, it throws an exception like this: ``` >>> spark.range(10).collect() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/dataframe.py", line 1710, in collect table, schema = self._to_table() File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/dataframe.py", line 1722, in _to_table table, schema = self._session.client.to_table(query, self._plan.observations) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 839, in to_table table, schema, _, _, _ = self._execute_and_fetch(req, observations) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1295, in _execute_and_fetch for response in self._execute_and_fetch_as_iterator(req, observations): File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1273, in _execute_and_fetch_as_iterator self._handle_error(error) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1521, in _handle_error raise error File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1266, in _execute_and_fetch_as_iterator yield from handle_response(b) File "/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1193, in handle_response self._verify_response_integrity(b) File 
"/Users/martin.grund/Development/spark/python/pyspark/sql/connect/client/core.py", line 1622, in _verify_response_integrity raise SparkConnectException( pyspark.errors.exceptions.connect.SparkConnectException: Received incorrect server side session identifier for request. Please restart Spark Session. (9493c83d-cfa4-488f-9522-838ef3df90bf != c5302e8f-170d-477e-908d-299927b68fd8) ``` ### Why are the changes needed? Stability ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Existing tests cover the basic invariant. - Added new tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43664 from grundprinzip/SPARK-45798. Authored-by: Martin Grund <martin.grund@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent ce818ba commit d9c5f9d

30 files changed

Lines changed: 671 additions & 252 deletions

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
4545
private var retryPolicy: GrpcRetryHandler.RetryPolicy = _
4646
private var bstub: CustomSparkConnectBlockingStub = _
4747
private var stub: CustomSparkConnectStub = _
48+
private var state: SparkConnectStubState = _
4849

4950
private def startDummyServer(): Unit = {
5051
service = new DummySparkConnectService()
@@ -58,8 +59,9 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
5859
private def createArtifactManager(): Unit = {
5960
channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
6061
retryPolicy = GrpcRetryHandler.RetryPolicy()
61-
bstub = new CustomSparkConnectBlockingStub(channel, retryPolicy)
62-
stub = new CustomSparkConnectStub(channel, retryPolicy)
62+
state = new SparkConnectStubState(channel, retryPolicy)
63+
bstub = new CustomSparkConnectBlockingStub(channel, state)
64+
stub = new CustomSparkConnectStub(channel, state)
6365
artifactManager = new ArtifactManager(Configuration(), "", bstub, stub)
6466
}
6567

connector/connect/common/src/main/protobuf/spark/connect/base.proto

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,12 @@ message AnalyzePlanRequest {
197197

198198
// Response to performing analysis of the query. Contains relevant metadata to be able to
199199
// reason about the performance.
200+
// Next ID: 16
200201
message AnalyzePlanResponse {
201202
string session_id = 1;
203+
// Server-side generated idempotency key that the client can use to assert that the server side
204+
// session has not changed.
205+
string server_side_session_id = 15;
202206

203207
oneof result {
204208
Schema schema = 2;
@@ -317,8 +321,12 @@ message ExecutePlanRequest {
317321

318322
// The response of a query, can be one or more for each request. Responses belonging to the
319323
// same input query, carry the same `session_id`.
324+
// Next ID: 16
320325
message ExecutePlanResponse {
321326
string session_id = 1;
327+
// Server-side generated idempotency key that the client can use to assert that the server side
328+
// session has not changed.
329+
string server_side_session_id = 15;
322330

323331
// Identifies the ExecutePlan execution.
324332
// If set by the client in ExecutePlanRequest.operationId, that value is returned.
@@ -492,8 +500,12 @@ message ConfigRequest {
492500
}
493501

494502
// Response to the config request.
503+
// Next ID: 5
495504
message ConfigResponse {
496505
string session_id = 1;
506+
// Server-side generated idempotency key that the client can use to assert that the server side
507+
// session has not changed.
508+
string server_side_session_id = 4;
497509

498510
// (Optional) The result key-value pairs.
499511
//
@@ -584,7 +596,17 @@ message AddArtifactsRequest {
584596

585597
// Response to adding an artifact. Contains relevant metadata to verify successful transfer of
586598
// artifact(s).
599+
// Next ID: 4
587600
message AddArtifactsResponse {
601+
// Session id in which the AddArtifact was running.
602+
string session_id = 2;
603+
// Server-side generated idempotency key that the client can use to assert that the server side
604+
// session has not changed.
605+
string server_side_session_id = 3;
606+
607+
// The list of artifact(s) seen by the server.
608+
repeated ArtifactSummary artifacts = 1;
609+
588610
// Metadata of an artifact.
589611
message ArtifactSummary {
590612
string name = 1;
@@ -593,9 +615,6 @@ message AddArtifactsResponse {
593615
// If false, the client may choose to resend the artifact specified by `name`.
594616
bool is_crc_successful = 2;
595617
}
596-
597-
// The list of artifact(s) seen by the server.
598-
repeated ArtifactSummary artifacts = 1;
599618
}
600619

601620
// Request to get current statuses of artifacts at the server side.
@@ -626,14 +645,20 @@ message ArtifactStatusesRequest {
626645
}
627646

628647
// Response to checking artifact statuses.
648+
// Next ID: 4
629649
message ArtifactStatusesResponse {
650+
// Session id in which the ArtifactStatus was running.
651+
string session_id = 2;
652+
// Server-side generated idempotency key that the client can use to assert that the server side
653+
// session has not changed.
654+
string server_side_session_id = 3;
655+
// A map of artifact names to their statuses.
656+
map<string, ArtifactStatus> statuses = 1;
657+
630658
message ArtifactStatus {
631659
// Whether the particular artifact exists at the server.
632660
bool exists = 1;
633661
}
634-
635-
// A map of artifact names to their statuses.
636-
map<string, ArtifactStatus> statuses = 1;
637662
}
638663

639664
message InterruptRequest {
@@ -678,12 +703,17 @@ message InterruptRequest {
678703
}
679704
}
680705

706+
// Next ID: 4
681707
message InterruptResponse {
682708
// Session id in which the interrupt was running.
683709
string session_id = 1;
710+
// Server-side generated idempotency key that the client can use to assert that the server side
711+
// session has not changed.
712+
string server_side_session_id = 3;
684713

685714
// Operation ids of the executions which were interrupted.
686715
repeated string interrupted_ids = 2;
716+
687717
}
688718

689719
message ReattachOptions {
@@ -774,9 +804,13 @@ message ReleaseExecuteRequest {
774804
}
775805
}
776806

807+
// Next ID: 4
777808
message ReleaseExecuteResponse {
778809
// Session id in which the release was running.
779810
string session_id = 1;
811+
// Server-side generated idempotency key that the client can use to assert that the server side
812+
// session has not changed.
813+
string server_side_session_id = 3;
780814

781815
// Operation id of the operation on which the release executed.
782816
// If the operation couldn't be found (because e.g. it was concurrently released), will be unset.
@@ -803,9 +837,13 @@ message ReleaseSessionRequest {
803837
optional string client_type = 3;
804838
}
805839

840+
// Next ID: 3
806841
message ReleaseSessionResponse {
807842
// Session id of the session on which the release executed.
808843
string session_id = 1;
844+
// Server-side generated idempotency key that the client can use to assert that the server side
845+
// session has not changed.
846+
string server_side_session_id = 2;
809847
}
810848

811849
message FetchErrorDetailsRequest {
@@ -828,8 +866,21 @@ message FetchErrorDetailsRequest {
828866
optional string client_type = 4;
829867
}
830868

869+
// Next ID: 5
831870
message FetchErrorDetailsResponse {
832871

872+
// Server-side generated idempotency key that the client can use to assert that the server side
873+
// session has not changed.
874+
string server_side_session_id = 3;
875+
876+
string session_id = 4;
877+
878+
// The index of the root error in errors. The field will not be set if the error is not found.
879+
optional int32 root_error_idx = 1;
880+
881+
// A list of errors.
882+
repeated Error errors = 2;
883+
833884
message StackTraceElement {
834885
// The fully qualified name of the class containing the execution point.
835886
string declaring_class = 1;
@@ -914,12 +965,6 @@ message FetchErrorDetailsResponse {
914965
// The structured data of a SparkThrowable exception.
915966
optional SparkThrowable spark_throwable = 5;
916967
}
917-
918-
// The index of the root error in errors. The field will not be set if the error is not found.
919-
optional int32 root_error_idx = 1;
920-
921-
// A list of errors.
922-
repeated Error errors = 2;
923968
}
924969

925970
// Main interface for the SparkConnect service.

connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ class ArtifactManager(
123123
.setSessionId(sessionId)
124124
.addAllNames(Arrays.asList(artifactName))
125125
.build()
126-
val statuses = bstub.artifactStatus(request).getStatusesMap
126+
val response = bstub.artifactStatus(request)
127+
if (response.getSessionId != sessionId) {
128+
throw new IllegalStateException(
129+
s"Session ID mismatch: $sessionId != ${response.getSessionId}")
130+
}
131+
val statuses = response.getStatusesMap
127132
if (statuses.containsKey(artifactName)) {
128133
statuses.get(artifactName).getExists
129134
} else false
@@ -179,6 +184,9 @@ class ArtifactManager(
179184
val responseHandler = new StreamObserver[proto.AddArtifactsResponse] {
180185
private val summaries = mutable.Buffer.empty[ArtifactSummary]
181186
override def onNext(v: AddArtifactsResponse): Unit = {
187+
if (v.getSessionId != sessionId) {
188+
throw new IllegalStateException(s"Session ID mismatch: $sessionId != ${v.getSessionId}")
189+
}
182190
v.getArtifactsList.forEach { summary =>
183191
summaries += summary
184192
}

connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ import org.apache.spark.connect.proto._
2424

2525
private[connect] class CustomSparkConnectBlockingStub(
2626
channel: ManagedChannel,
27-
retryPolicy: GrpcRetryHandler.RetryPolicy) {
27+
stubState: SparkConnectStubState) {
2828

2929
private val stub = SparkConnectServiceGrpc.newBlockingStub(channel)
3030

31-
private val retryHandler = new GrpcRetryHandler(retryPolicy)
31+
private val retryHandler = stubState.retryHandler
3232

3333
// GrpcExceptionConverter with a GRPC stub for fetching error details from server.
34-
private val grpcExceptionConverter = new GrpcExceptionConverter(stub)
34+
private val grpcExceptionConverter = stubState.exceptionConverter
3535

3636
def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
3737
grpcExceptionConverter.convert(
@@ -44,7 +44,10 @@ private[connect] class CustomSparkConnectBlockingStub(
4444
request.getClientType,
4545
retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse](
4646
request,
47-
r => CloseableIterator(stub.executePlan(r).asScala)))
47+
r => {
48+
stubState.responseValidator.wrapIterator(
49+
CloseableIterator(stub.executePlan(r).asScala))
50+
}))
4851
}
4952
}
5053

@@ -59,7 +62,8 @@ private[connect] class CustomSparkConnectBlockingStub(
5962
request.getUserContext,
6063
request.getClientType,
6164
// Don't use retryHandler - own retry handling is inside.
62-
new ExecutePlanResponseReattachableIterator(request, channel, retryPolicy))
65+
stubState.responseValidator.wrapIterator(
66+
new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryPolicy)))
6367
}
6468
}
6569

@@ -69,7 +73,9 @@ private[connect] class CustomSparkConnectBlockingStub(
6973
request.getUserContext,
7074
request.getClientType) {
7175
retryHandler.retry {
72-
stub.analyzePlan(request)
76+
stubState.responseValidator.verifyResponse {
77+
stub.analyzePlan(request)
78+
}
7379
}
7480
}
7581
}
@@ -80,7 +86,9 @@ private[connect] class CustomSparkConnectBlockingStub(
8086
request.getUserContext,
8187
request.getClientType) {
8288
retryHandler.retry {
83-
stub.config(request)
89+
stubState.responseValidator.verifyResponse {
90+
stub.config(request)
91+
}
8492
}
8593
}
8694
}
@@ -91,7 +99,9 @@ private[connect] class CustomSparkConnectBlockingStub(
9199
request.getUserContext,
92100
request.getClientType) {
93101
retryHandler.retry {
94-
stub.interrupt(request)
102+
stubState.responseValidator.verifyResponse {
103+
stub.interrupt(request)
104+
}
95105
}
96106
}
97107
}
@@ -102,7 +112,9 @@ private[connect] class CustomSparkConnectBlockingStub(
102112
request.getUserContext,
103113
request.getClientType) {
104114
retryHandler.retry {
105-
stub.releaseSession(request)
115+
stubState.responseValidator.verifyResponse {
116+
stub.releaseSession(request)
117+
}
106118
}
107119
}
108120
}
@@ -113,7 +125,9 @@ private[connect] class CustomSparkConnectBlockingStub(
113125
request.getUserContext,
114126
request.getClientType) {
115127
retryHandler.retry {
116-
stub.artifactStatus(request)
128+
stubState.responseValidator.verifyResponse {
129+
stub.artifactStatus(request)
130+
}
117131
}
118132
}
119133
}

connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse
2323

2424
private[client] class CustomSparkConnectStub(
2525
channel: ManagedChannel,
26-
retryPolicy: GrpcRetryHandler.RetryPolicy) {
26+
stubState: SparkConnectStubState) {
2727

2828
private val stub = SparkConnectServiceGrpc.newStub(channel)
29-
private val retryHandler = new GrpcRetryHandler(retryPolicy)
3029

3130
def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse])
3231
: StreamObserver[AddArtifactsRequest] = {
33-
retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts)
32+
stubState.responseValidator.wrapStreamObserver(
33+
stubState.retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts))
3434
}
3535
}

connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ class ExecutePlanResponseReattachableIterator(
9797
private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] =
9898
Some(rawBlockingStub.executePlan(initialRequest))
9999

100+
// Server side session ID, used to detect if the server side session changed. This is set upon
101+
// receiving the first response from the server.
102+
private var serverSideSessionId: Option[String] = None
103+
100104
override def innerIterator: Iterator[proto.ExecutePlanResponse] = iter match {
101105
case Some(it) => it.asScala
102106
case None =>
@@ -114,10 +118,23 @@ class ExecutePlanResponseReattachableIterator(
114118

115119
try {
116120
// Get next response, possibly triggering reattach in case of stream error.
117-
val ret = retry {
121+
val ret: proto.ExecutePlanResponse = retry {
118122
callIter(_.next())
119123
}
120124

125+
// Check if the server-side session state has changed. If this is the case, immediately
126+
// abandon execution.
127+
serverSideSessionId match {
128+
case Some(id) =>
129+
if (id != ret.getServerSideSessionId) {
130+
throw new IllegalStateException(
131+
s"Server side session ID changed. Create a new SparkSession to continue. " +
132+
s"(Old: $id, New: ${ret.getServerSideSessionId})")
133+
}
134+
case None =>
135+
serverSideSessionId = Some(ret.getServerSideSessionId)
136+
}
137+
121138
// Record last returned response, to know where to restart in case of reattach.
122139
lastReturnedResponseId = Some(ret.getResponseId)
123140
if (ret.hasResultComplete) {

connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@ import scala.jdk.CollectionConverters._
2222
import scala.reflect.ClassTag
2323

2424
import com.google.rpc.ErrorInfo
25-
import io.grpc.StatusRuntimeException
25+
import io.grpc.{ManagedChannel, StatusRuntimeException}
2626
import io.grpc.protobuf.StatusProto
2727
import org.json4s.DefaultFormats
2828
import org.json4s.jackson.JsonMethods
2929

3030
import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
31-
import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext}
32-
import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub
31+
import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, SparkConnectServiceGrpc, UserContext}
3332
import org.apache.spark.internal.Logging
3433
import org.apache.spark.sql.AnalysisException
3534
import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
@@ -49,10 +48,11 @@ import org.apache.spark.util.ArrayImplicits._
4948
* the ErrorInfo is missing, the exception will be constructed based on the StatusRuntimeException
5049
* itself.
5150
*/
52-
private[client] class GrpcExceptionConverter(grpcStub: SparkConnectServiceBlockingStub)
53-
extends Logging {
51+
private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Logging {
5452
import GrpcExceptionConverter._
5553

54+
val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel)
55+
5656
def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = {
5757
try {
5858
f

0 commit comments

Comments
 (0)