[SPARK-42585][CONNECT] Streaming of local relations
### What changes were proposed in this pull request?
In this PR, I propose to transfer a local relation to the server in a streaming fashion when its size exceeds the threshold defined by the SQL config `spark.sql.session.localRelationCacheThreshold` (64 MB by default). In particular (a condensed client-side sketch follows the list):
1. The client computes a `sha256` hash over the Arrow form of the local relation;
2. It checks whether the relation is already present on the server by sending the hash to the server;
3. If the server does not have the local relation, the client transfers it as an artifact named `cache/<sha256>`;
4. Once the relation is present on the server (either already cached or just transferred), the client transforms the logical plan by replacing the `LocalRelation` node with a `CachedLocalRelation` carrying the hash;
5. The server, in turn, converts `CachedLocalRelation` back into a `LocalRelation` by retrieving the relation body from its local cache.
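
Put together, the client-side flow is a small decision in `createDataset`; the following is a condensed sketch of the change shown in the diff below, not additional code:

```scala
// Serialize the local collection to Arrow and measure its size.
val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data, timeZoneId, allocator)
val threshold = conf.get("spark.sql.session.localRelationCacheThreshold").toInt

if (arrowDataSize <= threshold) {
  // Small relation: inline the Arrow bytes into the plan, as before.
  builder.getLocalRelationBuilder
    .setSchema(encoder.schema.json)
    .setData(arrowData)
} else {
  // Large relation: upload it once as a `cache/<sha256>` artifact and
  // reference it from the plan by hash only.
  val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
  builder.getCachedLocalRelationBuilder
    .setUserId(client.userId)
    .setSessionId(client.sessionId)
    .setHash(hash)
}
```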

#### Details of the implementation
The client sends the new `ArtifactStatusesRequest` message to check whether the local relation is cached at the server. The request goes through the new RPC endpoint `ArtifactStatus`, and the server answers with the new `ArtifactStatusesResponse` message; see **base.proto**.
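
On the wire this is a single unary call. A condensed sketch, mirroring `ArtifactManager.isCachedArtifact` from the diff below:

```scala
// Ask the server whether the artifact `cache/<sha256>` already exists in this session.
val artifactName = "cache/" + hash
val request = proto.ArtifactStatusesRequest
  .newBuilder()
  .setUserContext(userContext)
  .setSessionId(sessionId)
  .addAllNames(java.util.Arrays.asList(artifactName))
  .build()
val statuses = bstub.artifactStatus(request).getStatusesMap
val alreadyCached =
  statuses.containsKey(artifactName) && statuses.get(artifactName).getExists
```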

The client transfers the serialized (Arrow) body of the local relation together with its schema via the `AddArtifacts` RPC endpoint. The server, in turn, stores the received artifact in the block manager under a `CacheId`, which consists of three parts:
- `userId` - the identifier of the user that created the local relation,
- `sessionId` - the identifier of the session to which the relation belongs,
- `hash` - a SHA-256 hash of the relation body.

See **SparkConnectArtifactManager.addArtifact()**.
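
The cached blob itself is a simple length-prefixed concatenation of the Arrow data and the schema JSON. A sketch mirroring `SparkConnectClient.cacheLocalRelation` from the diff below:

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Layout of the cached artifact: [4-byte Arrow size][Arrow IPC stream][schema JSON, UTF-8].
val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
val locRelData = data.toByteArray
val locRel = ByteBuffer.allocate(4 + locRelData.length + schemaBytes.length)
locRel.putInt(size)
locRel.put(locRelData)
locRel.put(schemaBytes)
// The SHA-256 hex digest of these bytes becomes both the artifact name (`cache/<hash>`)
// and the `hash` component of the server-side CacheId(userId, sessionId, hash).
val hash = artifactManager.cacheArtifact(locRel.array())
```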

The current query blocks until the local relation has been cached on the server side.

When the server receives the query, it retrieves `userId`, `sessionId` and `hash` from `CachedLocalRelation`, and gets the local relation data from the block manager. See **SparkConnectPlanner.transformCachedLocalRelation()**.
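
`transformCachedLocalRelation` itself is not part of the excerpt below, but given the blob layout above, the server-side decoding amounts to roughly the following. This is a sketch under the assumption that the blob bytes have already been fetched from the block manager by `CacheId`; the helper name is hypothetical:

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Hypothetical helper: split a cached blob back into the Arrow payload and the schema JSON,
// inverting the [4-byte size][Arrow IPC stream][schema JSON] layout written by the client.
def decodeCachedRelation(blob: Array[Byte]): (Array[Byte], String) = {
  val buf = ByteBuffer.wrap(blob)
  val arrowSize = buf.getInt()
  val arrowData = new Array[Byte](arrowSize)
  buf.get(arrowData)
  val schemaBytes = new Array[Byte](buf.remaining())
  buf.get(schemaBytes)
  (arrowData, new String(schemaBytes, StandardCharsets.UTF_8))
}
```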

The occupied blocks in the block manager are removed when a user session is invalidated in `userSessionMapping`. See **SparkConnectService.RemoveSessionListener** and **BlockManager.removeCache()**.
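
For illustration only, the cleanup hook boils down to something like the snippet below; the exact listener wiring and the `removeCache` signature are assumptions based on the names above, not code from this excerpt:

```scala
// Assumed sketch: when the (userId, sessionId) entry is evicted from userSessionMapping,
// drop every cache/<hash> block that this session stored in the block manager.
// The removeCache(userId, sessionId) signature is an assumption for illustration.
SparkEnv.get.blockManager.removeCache(userId, sessionId)
```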

### Why are the changes needed?
To allow creating a DataFrame from a large local collection. Without the changes, `spark.createDataFrame(...)` fails with the following error:
```java
23/04/21 20:32:20 WARN NettyServerStream: Exception processing message
org.sparkproject.connect.grpc.StatusRuntimeException: RESOURCE_EXHAUSTED: gRPC message exceeds maximum size 134217728: 268435456
	at org.sparkproject.connect.grpc.Status.asRuntimeException(Status.java:526)
```
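
With the change in place, the same collection streams to the server instead of hitting the gRPC message limit. A minimal usage sketch on a Connect session (the lowered threshold is only for illustration; 64 MB is the default):

```scala
// Lower the streaming threshold to 1 MiB so even a modest collection takes the cached path.
spark.conf.set("spark.sql.session.localRelationCacheThreshold", (1024 * 1024).toString)

// The Arrow form of this collection exceeds the threshold, so the client uploads it once
// as a `cache/<sha256>` artifact and references it via CachedLocalRelation in the plan.
val bigStr = "x" * (1024 * 1024)
val df = spark.createDataFrame(Seq.tabulate(2)(i => (i, bigStr)))
assert(df.count() == 2)
```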

### Does this PR introduce _any_ user-facing change?
No. The changes extend the existing proto API.

### How was this patch tested?
By running the new tests:
```
$ build/sbt "test:testOnly *.ArtifactManagerSuite"
$ build/sbt "test:testOnly *.ClientE2ETestSuite"
$ build/sbt "test:testOnly *.ArtifactStatusesHandlerSuite"
```

Closes #40827 from MaxGekk/streaming-createDataFrame-2.

Authored-by: Max Gekk <max.gekk@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
MaxGekk authored and HyukjinKwon committed May 2, 2023
1 parent d26292c commit 0d7618a
Showing 23 changed files with 922 additions and 208 deletions.
@@ -119,12 +119,23 @@ class SparkSession private[sql] (

private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
val localRelationBuilder = builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
if (data.nonEmpty) {
val timeZoneId = conf.get("spark.sql.session.timeZone")
val arrowData = ConvertToArrow(encoder, data, timeZoneId, allocator)
localRelationBuilder.setData(arrowData)
val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data, timeZoneId, allocator)
if (arrowDataSize <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
} else {
val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
builder.getCachedLocalRelationBuilder
.setUserId(client.userId)
.setSessionId(client.sessionId)
.setHash(hash)
}
} else {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
}
}
}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.client
import java.io.{ByteArrayInputStream, InputStream}
import java.net.URI
import java.nio.file.{Files, Path, Paths}
import java.util.Arrays
import java.util.concurrent.CopyOnWriteArrayList
import java.util.zip.{CheckedInputStream, CRC32}

@@ -32,6 +33,7 @@ import Artifact._
import com.google.protobuf.ByteString
import io.grpc.ManagedChannel
import io.grpc.stub.StreamObserver
import org.apache.commons.codec.digest.DigestUtils.sha256Hex

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.AddArtifactsResponse
@@ -42,14 +44,20 @@ import org.apache.spark.util.{ThreadUtils, Utils}
* The Artifact Manager is responsible for handling and transferring artifacts from the local
* client to the server (local/remote).
* @param userContext
* @param sessionId
* An unique identifier of the session which the artifact manager belongs to.
* @param channel
*/
class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
class ArtifactManager(
userContext: proto.UserContext,
sessionId: String,
channel: ManagedChannel) {
// Using the midpoint recommendation of 32KiB for chunk size as specified in
// https://github.com/grpc/grpc.github.io/issues/371.
private val CHUNK_SIZE: Int = 32 * 1024

private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
private[this] val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]

/**
@@ -100,6 +108,31 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
*/
def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))

private def isCachedArtifact(hash: String): Boolean = {
val artifactName = CACHE_PREFIX + "/" + hash
val request = proto.ArtifactStatusesRequest
.newBuilder()
.setUserContext(userContext)
.setSessionId(sessionId)
.addAllNames(Arrays.asList(artifactName))
.build()
val statuses = bstub.artifactStatus(request).getStatusesMap
if (statuses.containsKey(artifactName)) {
statuses.get(artifactName).getExists
} else false
}

/**
* Cache the give blob at the session.
*/
def cacheArtifact(blob: Array[Byte]): String = {
val hash = sha256Hex(blob)
if (!isCachedArtifact(hash)) {
addArtifacts(newCacheArtifact(hash, new InMemory(blob)) :: Nil)
}
hash
}

/**
* Upload all class file artifacts from the local REPL(s) to the server.
*
@@ -182,6 +215,7 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
.setSessionId(sessionId)
artifacts.foreach { artifact =>
val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
@@ -236,6 +270,7 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
.setSessionId(sessionId)

val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
@@ -289,6 +324,7 @@ class Artifact private (val path: Path, val storage: LocalData) {
object Artifact {
val CLASS_PREFIX: Path = Paths.get("classes")
val JAR_PREFIX: Path = Paths.get("jars")
val CACHE_PREFIX: Path = Paths.get("cache")

def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
newArtifact(JAR_PREFIX, ".jar", fileName, storage)
@@ -298,6 +334,10 @@ object Artifact {
newArtifact(CLASS_PREFIX, ".class", fileName, storage)
}

def newCacheArtifact(id: String, storage: LocalData): Artifact = {
newArtifact(CACHE_PREFIX, "", Paths.get(id), storage)
}

private def newArtifact(
prefix: Path,
requiredSuffix: String,
@@ -17,8 +17,11 @@

package org.apache.spark.sql.connect.client

import com.google.protobuf.ByteString
import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel, ManagedChannelBuilder, Metadata, MethodDescriptor, Status, TlsChannelCredentials}
import java.net.URI
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.UUID
import java.util.concurrent.Executor
import scala.language.existentials
@@ -39,19 +42,21 @@ private[sql] class SparkConnectClient(

private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)

private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel)

/**
* Placeholder method.
* @return
* User ID.
*/
private[client] def userId: String = userContext.getUserId()
private[sql] def userId: String = userContext.getUserId()

// Generate a unique session ID for this client. This UUID must be unique to allow
// concurrent Spark sessions of the same user. If the channel is closed, creating
// a new client will create a new session ID.
private[client] val sessionId: String = UUID.randomUUID.toString
private[sql] val sessionId: String = UUID.randomUUID.toString

private[client] val artifactManager: ArtifactManager = {
new ArtifactManager(userContext, sessionId, channel)
}

/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
@@ -215,6 +220,19 @@ private[sql] class SparkConnectClient(
def shutdown(): Unit = {
channel.shutdownNow()
}

/**
* Cache the given local relation at the server, and return its key in the remote cache.
*/
def cacheLocalRelation(size: Int, data: ByteString, schema: String): String = {
val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
val locRelData = data.toByteArray
val locRel = ByteBuffer.allocate(4 + locRelData.length + schemaBytes.length)
locRel.putInt(size)
locRel.put(locRelData)
locRel.put(schemaBytes)
artifactManager.cacheArtifact(locRel.array())
}
}

object SparkConnectClient {
@@ -34,13 +34,13 @@ import org.apache.spark.sql.util.ArrowUtils
private[sql] object ConvertToArrow {

/**
* Convert an iterator of common Scala objects into a sinlge Arrow IPC Stream.
* Convert an iterator of common Scala objects into a single Arrow IPC Stream.
*/
def apply[T](
encoder: AgnosticEncoder[T],
data: Iterator[T],
timeZoneId: String,
bufferAllocator: BufferAllocator): ByteString = {
bufferAllocator: BufferAllocator): (ByteString, Int) = {
val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId)
val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
val writer: ArrowWriter = ArrowWriter.create(root)
@@ -64,7 +64,7 @@ private[sql] object ConvertToArrow {
ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)

// Done
bytes.toByteString
(bytes.toByteString, bytes.size)
} finally {
root.close()
}
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
@@ -853,6 +854,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}

test("SparkSession.createDataFrame - large data set") {
val threshold = 1024 * 1024
withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) {
val count = 2
val suffix = "abcdef"
val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString + suffix
val data = Seq.tabulate(count)(i => (i, str))
val df = spark.createDataFrame(data)
assert(df.count() === count)
assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
}
}
}

private[sql] case class MyType(id: Long, a: Double, b: Double)
@@ -49,7 +49,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {

private def createArtifactManager(): Unit = {
channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), channel)
artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), "", channel)
}

override def beforeEach(): Unit = {
@@ -542,6 +542,43 @@ message AddArtifactsResponse {
repeated ArtifactSummary artifacts = 1;
}

// Request to get current statuses of artifacts at the server side.
message ArtifactStatusesRequest {
// (Required)
//
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
string session_id = 1;

// User context
UserContext user_context = 2;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 3;

// The name of the artifact is expected in the form of a "Relative Path" that is made up of a
// sequence of directories and the final file element.
// Examples of "Relative Path"s: "jars/test.jar", "classes/xyz.class", "abc.xyz", "a/b/X.jar".
// The server is expected to maintain the hierarchy of files as defined by their name. (i.e
// The relative path of the file on the server's filesystem will be the same as the name of
// the provided artifact)
repeated string names = 4;
}

// Response to checking artifact statuses.
message ArtifactStatusesResponse {
message ArtifactStatus {
// Exists or not particular artifact at the server.
bool exists = 1;
}

// A map of artifact names to their statuses.
map<string, ArtifactStatus> statuses = 1;
}

// Main interface for the SparkConnect service.
service SparkConnectService {

@@ -559,5 +596,8 @@ service SparkConnectService {
// Add artifacts to the session and returns a [[AddArtifactsResponse]] containing metadata about
// the added artifacts.
rpc AddArtifacts(stream AddArtifactsRequest) returns (AddArtifactsResponse) {}

// Check statuses of artifacts in the session and returns them in a [[ArtifactStatusesResponse]]
rpc ArtifactStatus(ArtifactStatusesRequest) returns (ArtifactStatusesResponse) {}
}

@@ -68,6 +68,7 @@ message Relation {
WithWatermark with_watermark = 33;
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
CachedLocalRelation cached_local_relation = 36;

// NA functions
NAFill fill_na = 90;
@@ -381,6 +382,18 @@ message LocalRelation {
optional string schema = 2;
}

// A local relation that has been cached already.
message CachedLocalRelation {
// (Required) An identifier of the user which created the local relation
string userId = 1;

// (Required) An identifier of the Spark SQL session in which the user created the local relation.
string sessionId = 2;

// (Required) A sha-256 hash of the serialized local relation.
string hash = 3;
}

// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
// (Required) Input relation for a Sample.
@@ -22,9 +22,11 @@ import java.nio.file.{Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.storage.{CacheId, StorageLevel}
import org.apache.spark.util.Utils

/**
@@ -87,11 +89,28 @@ class SparkConnectArtifactManager private[connect] {
* @param serverLocalStagingPath
*/
private[connect] def addArtifact(
session: SparkSession,
sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
require(!remoteRelativePath.isAbsolute)
if (remoteRelativePath.startsWith("classes/")) {
if (remoteRelativePath.startsWith("cache/")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
val blockManager = sessionHolder.session.sparkContext.env.blockManager
val blockId = CacheId(
userId = sessionHolder.userId,
sessionId = sessionHolder.sessionId,
hash = remoteRelativePath.toString.stripPrefix("cache/"))
val updater = blockManager.TempFileBasedBlockStoreUpdater(
blockId = blockId,
level = StorageLevel.MEMORY_AND_DISK_SER,
classTag = implicitly[ClassTag[Array[Byte]]],
tmpFile = tmpFile,
blockSize = tmpFile.length(),
tellMaster = false)
updater.save()
}(catchBlock = { tmpFile.delete() })
} else if (remoteRelativePath.startsWith("classes/")) {
// Move class files to common location (shared among all users)
val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
Files.createDirectories(target.getParent)
@@ -110,7 +129,7 @@ class SparkConnectArtifactManager private[connect] {
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith("jars")) {
// Adding Jars to the underlying spark context (visible to all users)
session.sessionState.resourceLoader.addJar(target.toString)
sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
jarsList.add(target)
}
}