From 6520cf924b334945e67a7e4653ad6b79700b83b0 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 10 Jul 2025 00:04:55 -0700 Subject: [PATCH] [SPARK-52748] Support `defineDataset` --- Sources/SparkConnect/Extension.swift | 12 +++++++ Sources/SparkConnect/SparkConnectClient.swift | 35 +++++++++++++++++++ Sources/SparkConnect/SparkConnectError.swift | 1 + Sources/SparkConnect/TypeAliases.swift | 1 + .../SparkConnectClientTests.swift | 22 ++++++++++++ 5 files changed, 71 insertions(+) diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift index 4307a94..5ae22e8 100644 --- a/Sources/SparkConnect/Extension.swift +++ b/Sources/SparkConnect/Extension.swift @@ -181,6 +181,18 @@ extension String { default: .UNRECOGNIZED(-1) } } + + var toDatasetType: DatasetType { + let mode = + switch self { + case "unspecified": DatasetType.unspecified + case "materializedView": DatasetType.materializedView + case "table": DatasetType.table + case "temporaryView": DatasetType.temporaryView + default: DatasetType.UNRECOGNIZED(-1) + } + return mode + } } extension [String: String] { diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift index e86e8ba..109ee0d 100644 --- a/Sources/SparkConnect/SparkConnectClient.swift +++ b/Sources/SparkConnect/SparkConnectClient.swift @@ -145,6 +145,8 @@ public actor SparkConnectClient { throw SparkConnectError.InvalidViewName case let m where m.contains("DATA_SOURCE_NOT_FOUND"): throw SparkConnectError.DataSourceNotFound + case let m where m.contains("DATASET_TYPE_UNSPECIFIED"): + throw SparkConnectError.DatasetTypeUnspecified default: throw error } @@ -1237,6 +1239,39 @@ public actor SparkConnectClient { } } + @discardableResult + func defineDataset( + _ dataflowGraphID: String, + _ datasetName: String, + _ datasetType: String, + _ comment: String? = nil + ) async throws -> Bool { + try await withGPRC { client in + if UUID(uuidString: dataflowGraphID) == nil { + throw SparkConnectError.InvalidArgument + } + + var defineDataset = Spark_Connect_PipelineCommand.DefineDataset() + defineDataset.dataflowGraphID = dataflowGraphID + defineDataset.datasetName = datasetName + defineDataset.datasetType = datasetType.toDatasetType + if let comment { + defineDataset.comment = comment + } + + var pipelineCommand = Spark_Connect_PipelineCommand() + pipelineCommand.commandType = .defineDataset(defineDataset) + + var command = Spark_Connect_Command() + command.commandType = .pipelineCommand(pipelineCommand) + + let responses = try await execute(self.sessionID!, command) + return responses.contains { + $0.responseType == .pipelineCommandResult(Spark_Connect_PipelineCommandResult()) + } + } + } + private enum URIParams { static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size" static let PARAM_SESSION_ID = "session_id" diff --git a/Sources/SparkConnect/SparkConnectError.swift b/Sources/SparkConnect/SparkConnectError.swift index 9c8d1c6..dde93c3 100644 --- a/Sources/SparkConnect/SparkConnectError.swift +++ b/Sources/SparkConnect/SparkConnectError.swift @@ -22,6 +22,7 @@ public enum SparkConnectError: Error { case CatalogNotFound case ColumnNotFound case DataSourceNotFound + case DatasetTypeUnspecified case InvalidArgument case InvalidSessionID case InvalidType diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift index b061f32..c0bacdb 100644 --- a/Sources/SparkConnect/TypeAliases.swift +++ b/Sources/SparkConnect/TypeAliases.swift @@ -23,6 +23,7 @@ typealias AnalyzePlanResponse = Spark_Connect_AnalyzePlanResponse typealias Command = Spark_Connect_Command typealias ConfigRequest = Spark_Connect_ConfigRequest typealias DataSource = Spark_Connect_Read.DataSource +typealias DatasetType = Spark_Connect_DatasetType typealias DataType = Spark_Connect_DataType typealias DayTimeInterval = Spark_Connect_DataType.DayTimeInterval typealias Drop = Spark_Connect_Drop diff --git a/Tests/SparkConnectTests/SparkConnectClientTests.swift b/Tests/SparkConnectTests/SparkConnectClientTests.swift index 955a9c8..1983be2 100644 --- a/Tests/SparkConnectTests/SparkConnectClientTests.swift +++ b/Tests/SparkConnectTests/SparkConnectClientTests.swift @@ -124,4 +124,26 @@ struct SparkConnectClientTests { } await client.stop() } + + @Test + func defineDataset() async throws { + let client = SparkConnectClient(remote: TEST_REMOTE) + let response = try await client.connect(UUID().uuidString) + + try await #require(throws: SparkConnectError.InvalidArgument) { + try await client.defineDataset("not-a-uuid-format", "ds1", "table") + } + + if response.sparkVersion.version.starts(with: "4.1") { + let dataflowGraphID = try await client.createDataflowGraph() + #expect(UUID(uuidString: dataflowGraphID) != nil) + try await #require(throws: SparkConnectError.DatasetTypeUnspecified) { + try await client.defineDataset(dataflowGraphID, "ds1", "unspecified") + } + #expect(try await client.defineDataset(dataflowGraphID, "ds2", "materializedView")) + #expect(try await client.defineDataset(dataflowGraphID, "ds3", "table")) + #expect(try await client.defineDataset(dataflowGraphID, "ds4", "temporaryView")) + } + await client.stop() + } }