From 19717f2dcc3563a92d66988f69fe1686fe9172f2 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 10 Jul 2025 06:28:12 -0700 Subject: [PATCH] [SPARK-52756] Support `defineFlow` --- Sources/SparkConnect/SparkConnectClient.swift | 31 +++++++++++++++++++ .../SparkConnectClientTests.swift | 18 +++++++++++ 2 files changed, 49 insertions(+) diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index 109ee0d..8ea6250 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -1272,6 +1272,37 @@ public actor SparkConnectClient { } } + @discardableResult + func defineFlow( + _ dataflowGraphID: String, + _ flowName: String, + _ targetDatasetName: String, + _ relation: Relation + ) async throws -> Bool { + try await withGPRC { client in + if UUID(uuidString: dataflowGraphID) == nil { + throw SparkConnectError.InvalidArgument + } + + var defineFlow = Spark_Connect_PipelineCommand.DefineFlow() + defineFlow.dataflowGraphID = dataflowGraphID + defineFlow.flowName = flowName + defineFlow.targetDatasetName = targetDatasetName + defineFlow.plan = relation + + var pipelineCommand = Spark_Connect_PipelineCommand() + pipelineCommand.commandType = .defineFlow(defineFlow) + + var command = Spark_Connect_Command() + command.commandType = .pipelineCommand(pipelineCommand) + + let responses = try await execute(self.sessionID!, command) + return responses.contains { + $0.responseType == .pipelineCommandResult(Spark_Connect_PipelineCommandResult()) + } + } + } + private enum URIParams { static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size" static let PARAM_SESSION_ID = "session_id" diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index 1983be2..72e31ba 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -146,4 +146,22 @@ struct SparkConnectClientTests { } await client.stop() } + + @Test + func defineFlow() async throws { + let client = SparkConnectClient(remote: TEST_REMOTE) + let response = try await client.connect(UUID().uuidString) + + try await #require(throws: SparkConnectError.InvalidArgument) { + try await client.defineFlow("not-a-uuid-format", "f1", "ds1", Relation()) + } + + if response.sparkVersion.version.starts(with: "4.1") { + let dataflowGraphID = try await client.createDataflowGraph() + #expect(UUID(uuidString: dataflowGraphID) != nil) + let relation = await client.getLocalRelation().root + #expect(try await client.defineFlow(dataflowGraphID, "f1", "ds1", relation)) + } + await client.stop() + } }