diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 208601f..023265c 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -1183,6 +1183,37 @@ public actor SparkConnectClient { return try await execute(self.sessionID!, command) } + @discardableResult + func createDataflowGraph( + _ defaultCatalog: String? = nil, + _ defaultDatabase: String? = nil, + _ sqlConf: [String: String]? = nil + ) async throws -> String { + try await withGPRC { client in + var graph = Spark_Connect_PipelineCommand.CreateDataflowGraph() + if let defaultCatalog { + graph.defaultCatalog = defaultCatalog + } + if let defaultDatabase { + graph.defaultDatabase = defaultDatabase + } + if let sqlConf { + graph.sqlConf = sqlConf + } + + var pipelineCommand = Spark_Connect_PipelineCommand() + pipelineCommand.commandType = .createDataflowGraph(graph) + + var command = Spark_Connect_Command() + command.commandType = .pipelineCommand(pipelineCommand) + + let response = try await execute(self.sessionID!, command) + let result = response.first!.pipelineCommandResult.createDataflowGraphResult + + return result.dataflowGraphID + } + } + private enum URIParams { static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size" static let PARAM_SESSION_ID = "session_id" diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index cd57905..58702b1 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -96,4 +96,15 @@ struct SparkConnectClientTests { } await client.stop() } + + @Test + func createDataflowGraph() async throws { + let client = SparkConnectClient(remote: TEST_REMOTE) + let response = try await client.connect(UUID().uuidString) + if response.sparkVersion.version.starts(with: "4.1") { + let dataflowGraphID = try await client.createDataflowGraph() + #expect(UUID(uuidString: dataflowGraphID) != nil) + } + await client.stop() + } }