diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index d373a5b..9ee8dbd 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -110,16 +110,24 @@ public actor DataFrame: Sendable {
     }
   }
 
-  private func analyzePlanIfNeeded() async throws {
-    if self._schema != nil {
-      return
-    }
+  private func withGPRC(
+    _ f: (GRPCClient) async throws -> Result
+  ) async throws -> Result {
     try await withGRPCClient(
       transport: .http2NIOPosix(
         target: .dns(host: spark.client.host, port: spark.client.port),
         transportSecurity: .plaintext
       )
     ) { client in
+      return try await f(client)
+    }
+  }
+
+  private func analyzePlanIfNeeded() async throws {
+    if self._schema != nil {
+      return
+    }
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(
         spark.client.getAnalyzePlanRequest(spark.sessionID, plan))
@@ -132,12 +140,7 @@ public actor DataFrame: Sendable {
 
   public func count() async throws -> Int64 {
     let counter = Atomic(Int64(0))
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
         response in
@@ -151,12 +154,7 @@ public actor DataFrame: Sendable {
 
   /// Execute the plan and try to fill `schema` and `batches`.
   private func execute() async throws {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       try await service.executePlan(spark.client.getExecutePlanRequest(plan)) {
         response in
@@ -394,12 +392,7 @@ public actor DataFrame: Sendable {
   /// (without any Spark executors).
   /// - Returns: True if the plan is local.
   public func isLocal() async throws -> Bool {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(spark.client.getIsLocal(spark.sessionID, plan))
       return response.isLocal.isLocal
@@ -410,12 +403,7 @@ public actor DataFrame: Sendable {
   /// arrives.
   /// - Returns: True if a plan is streaming.
   public func isStreaming() async throws -> Bool {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(spark.client.getIsStreaming(spark.sessionID, plan))
       return response.isStreaming.isStreaming
@@ -439,12 +427,7 @@ public actor DataFrame: Sendable {
   public func persist(storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK) async throws
     -> DataFrame
   {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       _ = try await service.analyzePlan(
         spark.client.getPersist(spark.sessionID, plan, storageLevel))
@@ -458,12 +441,7 @@ public actor DataFrame: Sendable {
   /// - Parameter blocking: Whether to block until all blocks are deleted.
   /// - Returns: A `DataFrame`
   public func unpersist(blocking: Bool = false) async throws -> DataFrame {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       _ = try await service.analyzePlan(spark.client.getUnpersist(spark.sessionID, plan, blocking))
     }
@@ -473,12 +451,7 @@ public actor DataFrame: Sendable {
 
   public var storageLevel: StorageLevel {
     get async throws {
-      try await withGRPCClient(
-        transport: .http2NIOPosix(
-          target: .dns(host: spark.client.host, port: spark.client.port),
-          transportSecurity: .plaintext
-        )
-      ) { client in
+      try await withGPRC { client in
         let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
         return try await service
           .analyzePlan(spark.client.getStorageLevel(spark.sessionID, plan)).getStorageLevel.storageLevel.toStorageLevel
@@ -505,12 +478,7 @@ public actor DataFrame: Sendable {
   /// - Parameter mode: the expected output format of plans;
   /// `simple`, `extended`, `codegen`, `cost`, `formatted`.
   public func explain(_ mode: String) async throws {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode))
       print(response.explain.explainString)
@@ -522,12 +490,7 @@ public actor DataFrame: Sendable {
   /// results. Depending on the source relations, this may not find all input files. Duplicates are removed.
   /// - Returns: An array of file path strings.
   public func inputFiles() async throws -> [String] {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(spark.client.getInputFiles(spark.sessionID, plan))
       return response.inputFiles.files
@@ -542,12 +505,7 @@ public actor DataFrame: Sendable {
   /// Prints the schema up to the given level to the console in a nice tree format.
   /// - Parameter level: A level to be printed.
   public func printSchema(_ level: Int32) async throws {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: spark.client.host, port: spark.client.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
       let response = try await service.analyzePlan(spark.client.getTreeString(spark.sessionID, plan, level))
       print(response.treeString.treeString)
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index e058177..4854159 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -52,12 +52,7 @@ public actor SparkConnectClient {
   /// - Parameter sessionID: A string for the session ID.
   /// - Returns: An `AnalyzePlanResponse` instance for `SparkVersion`
   func connect(_ sessionID: String) async throws -> AnalyzePlanResponse {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       // To prevent server-side `INVALID_HANDLE.FORMAT (SQLSTATE: HY000)` exception.
       if UUID(uuidString: sessionID) == nil {
         throw SparkConnectError.InvalidSessionIDException
@@ -73,6 +68,19 @@ public actor SparkConnectClient {
     }
   }
 
+  private func withGPRC(
+    _ f: (GRPCClient) async throws -> Result
+  ) async throws -> Result {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: self.host, port: self.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      return try await f(client)
+    }
+  }
+
   /// Create a ``ConfigRequest`` instance for `Set` operation.
   /// - Parameter map: A map of key-value string pairs.
   /// - Returns: A ``ConfigRequest`` instance.
@@ -89,12 +97,7 @@ public actor SparkConnectClient {
   /// - Parameter map: A map of key-value pairs to set.
   /// - Returns: Always return true.
   func setConf(map: [String: String]) async throws -> Bool {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = SparkConnectService.Client(wrapping: client)
       var request = getConfigRequestSet(map: map)
       request.clientType = clientType
@@ -118,12 +121,7 @@ public actor SparkConnectClient {
   }
 
   func unsetConf(keys: [String]) async throws -> Bool {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = SparkConnectService.Client(wrapping: client)
       var request = getConfigRequestUnset(keys: keys)
       request.clientType = clientType
@@ -150,12 +148,7 @@ public actor SparkConnectClient {
   /// - Parameter key: A string for key to look up.
   /// - Returns: A string for the value of the key.
   func getConf(_ key: String) async throws -> String {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = SparkConnectService.Client(wrapping: client)
       var request = getConfigRequestGet(keys: [key])
       request.clientType = clientType
@@ -179,12 +172,7 @@ public actor SparkConnectClient {
   /// Request the server to get all configurations.
   /// - Returns: A map of key-value pairs.
   func getConfAll() async throws -> [String: String] {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = SparkConnectService.Client(wrapping: client)
       var request = getConfigRequestGetAll()
       request.clientType = clientType
@@ -451,12 +439,7 @@ public actor SparkConnectClient {
 
   func execute(_ sessionID: String, _ command: Command) async throws -> [ExecutePlanResponse] {
     self.result.removeAll()
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = SparkConnectService.Client(wrapping: client)
       var plan = Plan()
       plan.opType = .command(command)
@@ -501,12 +484,7 @@ public actor SparkConnectClient {
   /// - Parameter ddlString: A string to parse.
   /// - Returns: A ``Spark_Connect_DataType`` instance.
   func ddlParse(_ ddlString: String) async throws -> Spark_Connect_DataType {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
       let service = SparkConnectService.Client(wrapping: client)
       let request = analyze(self.sessionID!, {
         var ddlParse = AnalyzePlanRequest.DDLParse()
@@ -522,12 +500,7 @@ public actor SparkConnectClient {
   /// - Parameter jsonString: A JSON string.
   /// - Returns: A DDL string.
   func jsonToDdl(_ jsonString: String) async throws -> String {
-    try await withGRPCClient(
-      transport: .http2NIOPosix(
-        target: .dns(host: self.host, port: self.port),
-        transportSecurity: .plaintext
-      )
-    ) { client in
+    try await withGPRC { client in
      let service = SparkConnectService.Client(wrapping: client)
      let request = analyze(self.sessionID!, {
        var jsonToDDL = AnalyzePlanRequest.JsonToDDL()
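
For reference, the change above factors the repeated `withGRPCClient(transport: .http2NIOPosix(...))` setup into a single closure-taking helper (`withGPRC`), so each call site only supplies the per-request work. The standalone sketch below illustrates that "with-style" wrapper shape without depending on the repo's gRPC types; `ConnectClient`, `FakeClient`, `withClient`, and `serverVersion` are hypothetical stand-ins, not the committed API, and the generic `<Result: Sendable>` parameter is an assumption about how such a helper would be declared.

```swift
// Hypothetical stand-in for the connected gRPC client handed to the closure.
struct FakeClient {
  func version() async throws -> String { "4.0.0" }
}

actor ConnectClient {
  let host: String
  let port: Int

  init(host: String, port: Int) {
    self.host = host
    self.port = port
  }

  // Mirrors the shape of `withGPRC`: set up the connection once, run the
  // caller's closure against the client, and propagate its result or error.
  private func withClient<Result: Sendable>(
    _ f: (FakeClient) async throws -> Result
  ) async throws -> Result {
    let client = FakeClient()  // stand-in for the withGRPCClient(transport:) setup
    return try await f(client)
  }

  // A call site now reads like the refactored methods in the diff:
  // only the request-specific logic lives here.
  func serverVersion() async throws -> String {
    try await withClient { client in
      try await client.version()
    }
  }
}
```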