[SPARK-42585][CONNECT] Streaming of local relations #40827

Closed · wants to merge 61 commits into from

Commits (61)
8f4c0d0
Add an end-to-end test
MaxGekk Apr 14, 2023
b43e79c
Add the SQL config spark.sql.session.localRelationCacheThreshold
MaxGekk Apr 15, 2023
d89d004
Update createDataset()
MaxGekk Apr 15, 2023
4adee82
Add the proto message: CachedLocalRelation
MaxGekk Apr 17, 2023
33a2baf
Re-gen relations_pb2.py and relations_pb2.pyi
MaxGekk Apr 17, 2023
55bfac5
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 18, 2023
0a4639e
Re-gen relations_pb2.py and relations_pb2.pyi
MaxGekk Apr 18, 2023
a78b7c0
Re-gen golden files for PlanGenerationTestSuite
MaxGekk Apr 18, 2023
3f22c9f
Re-gen golden files
MaxGekk Apr 18, 2023
74b409e
Send the local relation to the remote cache
MaxGekk Apr 18, 2023
526e696
Partial implementation
MaxGekk Apr 19, 2023
c8b0bbc
Impl w/o exist checks
MaxGekk Apr 19, 2023
25c77ce
Bug fixes
MaxGekk Apr 19, 2023
f603081
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 20, 2023
b37b595
Re-gen relations_pb2.py
MaxGekk Apr 20, 2023
88b0e47
Bug fix: serialization
MaxGekk Apr 20, 2023
09f6fc5
Add a test to ArtifactManagerSuite
MaxGekk Apr 20, 2023
33b1650
Add a test to ClientE2ETestSuite
MaxGekk Apr 20, 2023
5932218
Trigger build
MaxGekk Apr 20, 2023
d369cf4
Reformat SparkConnectArtifactManager and SparkConnectPlanner
MaxGekk Apr 20, 2023
ba72916
Trigger build
MaxGekk Apr 21, 2023
3d8d157
Trigger build
MaxGekk Apr 21, 2023
856b5e7
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 22, 2023
bbc2632
Re-gen proto/relations_pb2.py
MaxGekk Apr 22, 2023
7af7bac
Add the rpc endpoint: ArtifactStatus
MaxGekk Apr 24, 2023
e6272b6
Re-gen proto/base_pb2
MaxGekk Apr 24, 2023
eec35b1
Implement SparkConnectArtifactStatusesHandler
MaxGekk Apr 24, 2023
d88d68f
Fix compilation errors
MaxGekk Apr 24, 2023
98b42e5
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 24, 2023
b6c9a9f
Fix coding style and compiler errors
MaxGekk Apr 24, 2023
11e4f42
Add the test suite: ArtifactStatusesHandlerSuite
MaxGekk Apr 24, 2023
3e18eaf
Bug fix: artifact has a prefix
MaxGekk Apr 25, 2023
73d0236
Is the artifact cached: check from the client.
MaxGekk Apr 25, 2023
d4f603a
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 25, 2023
e99bf32
Re-gen base_pb2.py
MaxGekk Apr 25, 2023
d37a24d
Add user and session id to BlockId
MaxGekk Apr 25, 2023
e11bd47
Re-gen relations_pb2
MaxGekk Apr 25, 2023
7c03387
Reformat connect files
MaxGekk Apr 25, 2023
f707659
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 26, 2023
6e86b1b
Re-gen relations_pb2
MaxGekk Apr 26, 2023
e1a5f70
Clean up cache entries in the block manager.
MaxGekk Apr 26, 2023
db404f4
Fix scala 2.13 errors
MaxGekk Apr 26, 2023
1e93ba4
Trigger build
MaxGekk Apr 26, 2023
68d1144
Trigger build
MaxGekk Apr 27, 2023
f73730a
Use Arrays.asList
MaxGekk Apr 27, 2023
9cf03de
Address Wenchen's review comments
MaxGekk Apr 27, 2023
be76d83
Re-gen base_pb2
MaxGekk Apr 27, 2023
e1cc5f6
Trigger build
MaxGekk Apr 28, 2023
5952148
Trigger build
MaxGekk Apr 28, 2023
b7bbad3
Implement remove cache in the unified way
MaxGekk Apr 28, 2023
ea0cc60
Remove a duplicate test
MaxGekk Apr 28, 2023
22841da
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 28, 2023
49c0c10
Revert "Remove a duplicate test"
MaxGekk Apr 28, 2023
0e87751
Don't tell to the master about caches
MaxGekk Apr 28, 2023
61e8e2e
Release the read lock of the caches
MaxGekk Apr 28, 2023
7dc4d1c
Make the artifact test more stable
MaxGekk Apr 28, 2023
1ec8fb5
Reformat SparkConnectPlanner
MaxGekk Apr 29, 2023
900867e
Merge remote-tracking branch 'origin/master' into streaming-createDat…
MaxGekk Apr 29, 2023
710d224
Reformat SparkConnectService
MaxGekk Apr 29, 2023
46eb354
Remove the dependency of SQLConf
MaxGekk May 1, 2023
2351069
Trigger build
MaxGekk May 1, 2023
@@ -119,12 +119,23 @@ class SparkSession private[sql] (

private def createDataset[T](encoder: AgnosticEncoder[T], data: Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
val localRelationBuilder = builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
if (data.nonEmpty) {
val timeZoneId = conf.get("spark.sql.session.timeZone")
val arrowData = ConvertToArrow(encoder, data, timeZoneId, allocator)
localRelationBuilder.setData(arrowData)
val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data, timeZoneId, allocator)
if (arrowDataSize <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
} else {
val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
builder.getCachedLocalRelationBuilder
.setUserId(client.userId)
.setSessionId(client.sessionId)
.setHash(hash)
}
} else {
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
}
}
}
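With this change, small local relations are still inlined in the plan, while any relation whose Arrow payload exceeds spark.sql.session.localRelationCacheThreshold is uploaded once as a cache artifact and referenced through a CachedLocalRelation. A minimal usage sketch from the client side (the lowered threshold and the data sizes are illustrative, not part of the PR):

// Sketch only: force the cached path by lowering the threshold below the payload size.
// Assumes a Spark Connect `spark` session created by this client module.
spark.conf.set("spark.sql.session.localRelationCacheThreshold", (64 * 1024).toString)

// Roughly 1 MiB of local data, so ConvertToArrow reports a size above the threshold and
// createDataset hands the Arrow blob to the ArtifactManager instead of embedding it.
val rows = Seq.tabulate(4)(i => (i, scala.util.Random.alphanumeric.take(256 * 1024).mkString))
val df = spark.createDataFrame(rows)
assert(df.count() == 4)
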
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.client
import java.io.{ByteArrayInputStream, InputStream}
import java.net.URI
import java.nio.file.{Files, Path, Paths}
import java.util.Arrays
import java.util.concurrent.CopyOnWriteArrayList
import java.util.zip.{CheckedInputStream, CRC32}

@@ -32,6 +33,7 @@ import Artifact._
import com.google.protobuf.ByteString
import io.grpc.ManagedChannel
import io.grpc.stub.StreamObserver
import org.apache.commons.codec.digest.DigestUtils.sha256Hex

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.AddArtifactsResponse
@@ -42,14 +44,20 @@ import org.apache.spark.util.{ThreadUtils, Utils}
* The Artifact Manager is responsible for handling and transferring artifacts from the local
* client to the server (local/remote).
* @param userContext
* @param sessionId
* A unique identifier of the session to which the artifact manager belongs.
* @param channel
*/
class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
class ArtifactManager(
userContext: proto.UserContext,
sessionId: String,
channel: ManagedChannel) {
// Using the midpoint recommendation of 32KiB for chunk size as specified in
// https://github.com/grpc/grpc.github.io/issues/371.
private val CHUNK_SIZE: Int = 32 * 1024

private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
private[this] val bstub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]

/**
@@ -100,6 +108,31 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
*/
def addArtifacts(uris: Seq[URI]): Unit = addArtifacts(uris.flatMap(parseArtifacts))

private def isCachedArtifact(hash: String): Boolean = {
val artifactName = CACHE_PREFIX + "/" + hash
val request = proto.ArtifactStatusesRequest
.newBuilder()
.setUserContext(userContext)
.setSessionId(sessionId)
.addAllNames(Arrays.asList(artifactName))
.build()
val statuses = bstub.artifactStatus(request).getStatusesMap
if (statuses.containsKey(artifactName)) {
statuses.get(artifactName).getExists
} else false
}

/**
* Cache the given blob at the session.
*/
def cacheArtifact(blob: Array[Byte]): String = {
val hash = sha256Hex(blob)
if (!isCachedArtifact(hash)) {
addArtifacts(newCacheArtifact(hash, new InMemory(blob)) :: Nil)
}
hash
}

/**
* Upload all class file artifacts from the local REPL(s) to the server.
*
@@ -182,6 +215,7 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
.setSessionId(sessionId)
artifacts.foreach { artifact =>
val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
@@ -236,6 +270,7 @@ class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel) {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
.setSessionId(sessionId)

val in = new CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
@@ -289,6 +324,7 @@ class Artifact private (val path: Path, val storage: LocalData) {
object Artifact {
val CLASS_PREFIX: Path = Paths.get("classes")
val JAR_PREFIX: Path = Paths.get("jars")
val CACHE_PREFIX: Path = Paths.get("cache")

def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
newArtifact(JAR_PREFIX, ".jar", fileName, storage)
@@ -298,6 +334,10 @@ object Artifact {
newArtifact(CLASS_PREFIX, ".class", fileName, storage)
}

def newCacheArtifact(id: String, storage: LocalData): Artifact = {
newArtifact(CACHE_PREFIX, "", Paths.get(id), storage)
}

private def newArtifact(
prefix: Path,
requiredSuffix: String,
@@ -17,8 +17,11 @@

package org.apache.spark.sql.connect.client

import com.google.protobuf.ByteString
import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc, InsecureChannelCredentials, ManagedChannel, ManagedChannelBuilder, Metadata, MethodDescriptor, Status, TlsChannelCredentials}
import java.net.URI
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.UUID
import java.util.concurrent.Executor
import scala.language.existentials
@@ -39,19 +42,21 @@ private[sql] class SparkConnectClient(

private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)

private[client] val artifactManager: ArtifactManager = new ArtifactManager(userContext, channel)

/**
* Placeholder method.
* @return
* User ID.
*/
private[client] def userId: String = userContext.getUserId()
private[sql] def userId: String = userContext.getUserId()

// Generate a unique session ID for this client. This UUID must be unique to allow
// concurrent Spark sessions of the same user. If the channel is closed, creating
// a new client will create a new session ID.
private[client] val sessionId: String = UUID.randomUUID.toString
private[sql] val sessionId: String = UUID.randomUUID.toString

private[client] val artifactManager: ArtifactManager = {
new ArtifactManager(userContext, sessionId, channel)
}

/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
@@ -215,6 +220,19 @@ private[sql] class SparkConnectClient(
def shutdown(): Unit = {
channel.shutdownNow()
}

/**
* Cache the given local relation at the server, and return its key in the remote cache.
*/
def cacheLocalRelation(size: Int, data: ByteString, schema: String): String = {
val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
val locRelData = data.toByteArray
val locRel = ByteBuffer.allocate(4 + locRelData.length + schemaBytes.length)
locRel.putInt(size)
locRel.put(locRelData)
locRel.put(schemaBytes)
artifactManager.cacheArtifact(locRel.array())
}
}

object SparkConnectClient {
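cacheLocalRelation packs a 4-byte length of the Arrow payload, then the payload itself, then the UTF-8 schema JSON into a single blob before handing it to the artifact manager. The server-side reader is not part of this hunk; a minimal decoding sketch, assuming that same layout:

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Sketch only: split a cached local-relation blob back into its Arrow bytes and schema JSON,
// mirroring the layout written by SparkConnectClient.cacheLocalRelation above.
def decodeCachedLocalRelation(blob: Array[Byte]): (Array[Byte], String) = {
  val buf = ByteBuffer.wrap(blob)
  val arrowSize = buf.getInt()                       // 4-byte length of the Arrow IPC stream
  val arrowData = new Array[Byte](arrowSize)
  buf.get(arrowData)                                 // Arrow IPC stream bytes
  val schemaBytes = new Array[Byte](buf.remaining()) // the rest is the schema JSON
  buf.get(schemaBytes)
  (arrowData, new String(schemaBytes, StandardCharsets.UTF_8))
}
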
@@ -34,13 +34,13 @@ import org.apache.spark.sql.util.ArrowUtils
private[sql] object ConvertToArrow {

/**
* Convert an iterator of common Scala objects into a sinlge Arrow IPC Stream.
* Convert an iterator of common Scala objects into a single Arrow IPC Stream.
*/
def apply[T](
encoder: AgnosticEncoder[T],
data: Iterator[T],
timeZoneId: String,
bufferAllocator: BufferAllocator): ByteString = {
bufferAllocator: BufferAllocator): (ByteString, Int) = {
val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId)
val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
val writer: ArrowWriter = ArrowWriter.create(root)
@@ -64,7 +64,7 @@ private[sql] object ConvertToArrow {
ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)

// Done
bytes.toByteString
(bytes.toByteString, bytes.size)
} finally {
root.close()
}
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
@@ -853,6 +854,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}

test("SparkSession.createDataFrame - large data set") {
val threshold = 1024 * 1024
withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key -> threshold.toString) {
val count = 2
val suffix = "abcdef"
val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString + suffix
val data = Seq.tabulate(count)(i => (i, str))
val df = spark.createDataFrame(data)
assert(df.count() === count)
assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
}
}
}

private[sql] case class MyType(id: Long, a: Double, b: Double)
@@ -49,7 +49,7 @@ class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {

private def createArtifactManager(): Unit = {
channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), channel)
artifactManager = new ArtifactManager(proto.UserContext.newBuilder().build(), "", channel)
}

override def beforeEach(): Unit = {
@@ -542,6 +542,43 @@ message AddArtifactsResponse {
repeated ArtifactSummary artifacts = 1;
}

// Request to get current statuses of artifacts at the server side.
message ArtifactStatusesRequest {
// (Required)
//
// The session_id specifies a spark session for a user id (which is specified
// by user_context.user_id). The session_id is set by the client to be able to
// collate streaming responses from different queries within the dedicated session.
string session_id = 1;

// User context
UserContext user_context = 2;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 3;

// The name of the artifact is expected in the form of a "Relative Path" that is made up of a
// sequence of directories and the final file element.
// Examples of "Relative Path"s: "jars/test.jar", "classes/xyz.class", "abc.xyz", "a/b/X.jar".
// The server is expected to maintain the hierarchy of files as defined by their name. (i.e
// The relative path of the file on the server's filesystem will be the same as the name of
// the provided artifact)
repeated string names = 4;
}

// Response to checking artifact statuses.
message ArtifactStatusesResponse {
message ArtifactStatus {
// Whether the particular artifact exists at the server.
bool exists = 1;
}

// A map of artifact names to their statuses.
map<string, ArtifactStatus> statuses = 1;
}

// Main interface for the SparkConnect service.
service SparkConnectService {

@@ -559,5 +596,8 @@ service SparkConnectService {
// Add artifacts to the session and returns a [[AddArtifactsResponse]] containing metadata about
// the added artifacts.
rpc AddArtifacts(stream AddArtifactsRequest) returns (AddArtifactsResponse) {}

// Check statuses of artifacts in the session and returns them in an [[ArtifactStatusesResponse]]
rpc ArtifactStatus(ArtifactStatusesRequest) returns (ArtifactStatusesResponse) {}
}

@@ -68,6 +68,7 @@ message Relation {
WithWatermark with_watermark = 33;
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
CachedLocalRelation cached_local_relation = 36;

// NA functions
NAFill fill_na = 90;
@@ -381,6 +382,18 @@ message LocalRelation {
optional string schema = 2;
}

// A local relation that has been cached already.
message CachedLocalRelation {
// (Required) An identifier of the user who created the local relation.
string userId = 1;

// (Required) An identifier of the Spark SQL session in which the user created the local relation.
string sessionId = 2;

// (Required) A sha-256 hash of the serialized local relation.
string hash = 3;
}

// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
// (Required) Input relation for a Sample.
@@ -22,9 +22,11 @@ import java.nio.file.{Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.storage.{CacheId, StorageLevel}
import org.apache.spark.util.Utils

/**
@@ -87,11 +89,28 @@ class SparkConnectArtifactManager private[connect] {
* @param serverLocalStagingPath
*/
private[connect] def addArtifact(
session: SparkSession,
sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
require(!remoteRelativePath.isAbsolute)
if (remoteRelativePath.startsWith("classes/")) {
if (remoteRelativePath.startsWith("cache/")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
val blockManager = sessionHolder.session.sparkContext.env.blockManager
val blockId = CacheId(
userId = sessionHolder.userId,
sessionId = sessionHolder.sessionId,
hash = remoteRelativePath.toString.stripPrefix("cache/"))
val updater = blockManager.TempFileBasedBlockStoreUpdater(
blockId = blockId,
level = StorageLevel.MEMORY_AND_DISK_SER,
classTag = implicitly[ClassTag[Array[Byte]]],
tmpFile = tmpFile,
blockSize = tmpFile.length(),
tellMaster = false)
updater.save()
}(catchBlock = { tmpFile.delete() })
} else if (remoteRelativePath.startsWith("classes/")) {
// Move class files to common location (shared among all users)
val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
Files.createDirectories(target.getParent)
@@ -110,7 +129,7 @@ class SparkConnectArtifactManager private[connect] {
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith("jars")) {
// Adding Jars to the underlying spark context (visible to all users)
session.sessionState.resourceLoader.addJar(target.toString)
sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
jarsList.add(target)
}
}
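On the server, artifacts under the cache/ prefix are written into the block manager keyed by CacheId(userId, sessionId, hash) and not reported to the master. The corresponding read path, resolving a CachedLocalRelation back into bytes in the planner, is outside this hunk; a hedged sketch of what such a lookup could look like, assuming the block was stored by addArtifact above:

// Sketch only: fetch the cached blob for a CachedLocalRelation from the local block manager.
// The getLocalBytes/releaseLock pairing mirrors the read-lock handling mentioned in the commits.
def readCachedRelation(
    sessionHolder: SessionHolder,
    rel: proto.CachedLocalRelation): Array[Byte] = {
  val blockManager = sessionHolder.session.sparkContext.env.blockManager
  val blockId = CacheId(rel.getUserId, rel.getSessionId, rel.getHash)
  blockManager.getLocalBytes(blockId) match {
    case Some(blockData) =>
      try {
        val buffer = blockData.toByteBuffer()
        val bytes = new Array[Byte](buffer.remaining())
        buffer.get(bytes)
        bytes
      } finally {
        blockManager.releaseLock(blockId)
      }
    case None =>
      throw new IllegalStateException(s"Local relation with hash ${rel.getHash} is not cached")
  }
}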