Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 22 additions & 64 deletions Sources/SparkConnect/DataFrame.swift
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,24 @@ public actor DataFrame: Sendable {
}
}

private func analyzePlanIfNeeded() async throws {
if self._schema != nil {
return
}
/// Runs `f` with a gRPC client connected to this DataFrame's Spark Connect endpoint.
///
/// The client is obtained from `withGRPCClient` using an HTTP/2 NIO POSIX transport that
/// resolves `spark.client.host` / `spark.client.port` through DNS with plaintext transport
/// security.
///
/// NOTE(review): the name looks like a transposition of `withGRPC`; renaming it would require
/// updating every call site, so it is left as-is here.
///
/// - Parameter f: The work to perform with the connected client.
/// - Returns: The value produced by `f`.
/// - Throws: Any error raised while setting up the client or thrown by `f`.
private func withGPRC<Result: Sendable>(
  _ f: (GRPCClient<GRPCNIOTransportHTTP2.HTTP2ClientTransport.Posix>) async throws -> Result
) async throws -> Result {
  try await withGRPCClient(
    transport: .http2NIOPosix(
      target: .dns(host: spark.client.host, port: spark.client.port),
      transportSecurity: .plaintext
    )
  ) { grpcClient in
    // Forward the connected client to the caller-supplied body and surface its result.
    try await f(grpcClient)
  }
}

private func analyzePlanIfNeeded() async throws {
if self._schema != nil {
return
}
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(
spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
Expand All @@ -132,12 +140,7 @@ public actor DataFrame: Sendable {
public func count() async throws -> Int64 {
let counter = Atomic(Int64(0))

try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
response in
Expand All @@ -151,12 +154,7 @@ public actor DataFrame: Sendable {

/// Execute the plan and try to fill `schema` and `batches`.
private func execute() async throws {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
response in
Expand Down Expand Up @@ -394,12 +392,7 @@ public actor DataFrame: Sendable {
/// (without any Spark executors).
/// - Returns: True if the plan is local.
public func isLocal() async throws -> Bool {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(spark.client.getIsLocal(spark.sessionID, plan))
return response.isLocal.isLocal
Expand All @@ -410,12 +403,7 @@ public actor DataFrame: Sendable {
/// arrives.
/// - Returns: True if a plan is streaming.
public func isStreaming() async throws -> Bool {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan))
return response.isStreaming.isStreaming
Expand All @@ -439,12 +427,7 @@ public actor DataFrame: Sendable {
public func persist(storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) async throws
-> DataFrame
{
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
_ = try await service.analyzePlan(
spark.client.getPersist(spark.sessionID, plan, storageLevel))
Expand All @@ -458,12 +441,7 @@ public actor DataFrame: Sendable {
/// - Parameter blocking: Whether to block until all blocks are deleted.
/// - Returns: A `DataFrame`
public func unpersist(blocking: Bool = false) async throws -> DataFrame {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
_ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
}
Expand All @@ -473,12 +451,7 @@ public actor DataFrame: Sendable {

public var storageLevel: StorageLevel {
get async throws {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
return try await service
.analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel
Expand All @@ -505,12 +478,7 @@ public actor DataFrame: Sendable {
/// - Parameter mode: the expected output format of plans;
/// `simple`, `extended`, `codegen`, `cost`, `formatted`.
public func explain(_ mode: String) async throws {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode))
print(response.explain.explainString)
Expand All @@ -522,12 +490,7 @@ public actor DataFrame: Sendable {
/// results. Depending on the source relations, this may not find all input files. Duplicates are removed.
/// - Returns: An array of file path strings.
public func inputFiles() async throws -> [String] {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(spark.client.getInputFiles(spark.sessionID, plan))
return response.inputFiles.files
Expand All @@ -542,12 +505,7 @@ public actor DataFrame: Sendable {
/// Prints the schema up to the given level to the console in a nice tree format.
/// - Parameter level: A level to be printed.
public func printSchema(_ level: Int32) async throws {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: spark.client.host, port: spark.client.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
let response = try await service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level))
print(response.treeString.treeString)
Expand Down
69 changes: 21 additions & 48 deletions Sources/SparkConnect/SparkConnectClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,7 @@ public actor SparkConnectClient {
/// - Parameter sessionID: A string for the session ID.
/// - Returns: An `AnalyzePlanResponse` instance for `SparkVersion`
func connect(_ sessionID: String) async throws -> AnalyzePlanResponse {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
// To prevent server-side `INVALID_HANDLE.FORMAT (SQLSTATE: HY000)` exception.
if UUID(uuidString: sessionID) == nil {
throw SparkConnectError.InvalidSessionIDException
Expand All @@ -73,6 +68,19 @@ public actor SparkConnectClient {
}
}

/// Runs `f` with a gRPC client connected to this client's configured host and port.
///
/// The client is obtained from `withGRPCClient` using an HTTP/2 NIO POSIX transport that
/// resolves `host` / `port` through DNS with plaintext transport security.
///
/// NOTE(review): the name looks like a transposition of `withGRPC`; renaming it would require
/// updating every call site, so it is left as-is here.
///
/// - Parameter f: The work to perform with the connected client.
/// - Returns: The value produced by `f`.
/// - Throws: Any error raised while setting up the client or thrown by `f`.
private func withGPRC<Result: Sendable>(
  _ f: (GRPCClient<GRPCNIOTransportHTTP2.HTTP2ClientTransport.Posix>) async throws -> Result
) async throws -> Result {
  try await withGRPCClient(
    transport: .http2NIOPosix(
      target: .dns(host: host, port: port),
      transportSecurity: .plaintext
    )
  ) { grpcClient in
    // Forward the connected client to the caller-supplied body and surface its result.
    try await f(grpcClient)
  }
}

/// Create a ``ConfigRequest`` instance for `Set` operation.
/// - Parameter map: A map of key-value string pairs.
/// - Returns: A ``ConfigRequest`` instance.
Expand All @@ -89,12 +97,7 @@ public actor SparkConnectClient {
/// - Parameter map: A map of key-value pairs to set.
/// - Returns: Always returns true.
func setConf(map: [String: String]) async throws -> Bool {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
var request = getConfigRequestSet(map: map)
request.clientType = clientType
Expand All @@ -118,12 +121,7 @@ public actor SparkConnectClient {
}

func unsetConf(keys: [String]) async throws -> Bool {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
var request = getConfigRequestUnset(keys: keys)
request.clientType = clientType
Expand All @@ -150,12 +148,7 @@ public actor SparkConnectClient {
/// - Parameter key: A string for key to look up.
/// - Returns: A string for the value of the key.
func getConf(_ key: String) async throws -> String {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
var request = getConfigRequestGet(keys: [key])
request.clientType = clientType
Expand All @@ -179,12 +172,7 @@ public actor SparkConnectClient {
/// Request the server to get all configurations.
/// - Returns: A map of key-value pairs.
func getConfAll() async throws -> [String: String] {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
var request = getConfigRequestGetAll()
request.clientType = clientType
Expand Down Expand Up @@ -451,12 +439,7 @@ public actor SparkConnectClient {

func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
self.result.removeAll()
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
var plan = Plan()
plan.opType = .command(command)
Expand Down Expand Up @@ -501,12 +484,7 @@ public actor SparkConnectClient {
/// - Parameter ddlString: A string to parse.
/// - Returns: A ``Spark_Connect_DataType`` instance.
func ddlParse(_ ddlString: String) async throws -> Spark_Connect_DataType {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
let request = analyze(self.sessionID!, {
var ddlParse = AnalyzePlanRequest.DDLParse()
Expand All @@ -522,12 +500,7 @@ public actor SparkConnectClient {
/// - Parameter jsonString: A JSON string.
/// - Returns: A DDL string.
func jsonToDdl(_ jsonString: String) async throws -> String {
try await withGRPCClient(
transport: .http2NIOPosix(
target: .dns(host: self.host, port: self.port),
transportSecurity: .plaintext
)
) { client in
try await withGPRC { client in
let service = SparkConnectService.Client(wrapping: client)
let request = analyze(self.sessionID!, {
var jsonToDDL = AnalyzePlanRequest.JsonToDDL()
Expand Down
Loading