From 68687b6e414c529b085ee1f5eec8c8a054ed0e80 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Thu, 27 Mar 2025 13:25:06 -0700
Subject: [PATCH] [SPARK-51642] Support `explain` for `DataFrame`

---
 Sources/SparkConnect/DataFrame.swift          | 26 +++++++++++++++++++
 Sources/SparkConnect/Extension.swift          | 12 +++++++++
 Sources/SparkConnect/SparkConnectClient.swift | 14 ++++++++++
 Sources/SparkConnect/TypeAliases.swift        |  1 +
 Tests/SparkConnectTests/DataFrameTests.swift  |  9 +++++++
 5 files changed, 62 insertions(+)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 504d8c2..b1ec758 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -330,4 +330,30 @@ public actor DataFrame: Sendable {
 
     return self
   }
+
+  /// Prints the explain output for this `DataFrame` using the `simple` mode.
+  public func explain() async throws {
+    try await explain("simple")
+  }
+
+  /// Prints the explain output for this `DataFrame`.
+  /// - Parameter extended: If `true`, uses the `extended` mode; otherwise `simple`.
+  public func explain(_ extended: Bool) async throws {
+    try await explain(extended ? "extended" : "simple")
+  }
+
+  /// Sends an analyze-plan request for this `DataFrame`'s plan and prints the returned explain string.
+  /// - Parameter mode: Explain mode name; it is converted to an `ExplainMode` by `getExplain`.
+  public func explain(_ mode: String) async throws {
+    try await withGRPCClient(
+      transport: .http2NIOPosix(
+        target: .dns(host: spark.client.host, port: spark.client.port),
+        transportSecurity: .plaintext
+      )
+    ) { client in
+      let service = Spark_Connect_SparkConnectService.Client(wrapping: client)
+      let response = try await service.analyzePlan(spark.client.getExplain(spark.sessionID, plan, mode))
+      print(response.explain.explainString)
+    }
+  }
 }
diff --git a/Sources/SparkConnect/Extension.swift b/Sources/SparkConnect/Extension.swift
index da330cc..1d470fe 100644
--- a/Sources/SparkConnect/Extension.swift
+++ b/Sources/SparkConnect/Extension.swift
@@ -57,6 +57,18 @@ extension String {
     expression.expression = self
     return expression
   }
+
+  /// The `ExplainMode` named by this string, matched case-insensitively.
+  /// Unrecognized names (and `"simple"` itself) map to `ExplainMode.simple`.
+  var toExplainMode: ExplainMode {
+    switch self.lowercased() {
+    case "codegen": ExplainMode.codegen
+    case "cost": ExplainMode.cost
+    case "extended": ExplainMode.extended
+    case "formatted": ExplainMode.formatted
+    default: ExplainMode.simple
+    }
+  }
 }
 
 extension [String: String] {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 0b2d523..4d28c34 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -282,6 +282,20 @@ public actor SparkConnectClient {
       })
   }
 
+  /// Builds an `AnalyzePlanRequest` whose analysis is an `Explain` of `plan`,
+  /// using the explain mode named by `mode` (converted via `String.toExplainMode`).
+  func getExplain(_ sessionID: String, _ plan: Plan, _ mode: String) async -> AnalyzePlanRequest
+  {
+    return analyze(
+      sessionID,
+      {
+        var explain = AnalyzePlanRequest.Explain()
+        explain.plan = plan
+        explain.explainMode = mode.toExplainMode
+        return OneOf_Analyze.explain(explain)
+      })
+  }
+
   static func getProject(_ child: Relation, _ cols: [String]) -> Plan {
     var project = Project()
     project.input = child
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 275ed9d..aa1e087 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -22,6 +22,7 @@ typealias ConfigRequest = Spark_Connect_ConfigRequest
 typealias DataSource = Spark_Connect_Read.DataSource
 typealias DataType = Spark_Connect_DataType
 typealias ExecutePlanRequest = Spark_Connect_ExecutePlanRequest
+typealias ExplainMode = AnalyzePlanRequest.Explain.ExplainMode
 typealias ExpressionString = Spark_Connect_Expression.ExpressionString
 typealias Filter = Spark_Connect_Filter
 typealias KeyValue = Spark_Connect_KeyValue
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index 29dddb4..87b8fa4 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -70,6 +70,15 @@ struct DataFrameTests {
     await spark.stop()
   }
 
+  @Test
+  func explain() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    try await spark.range(1).explain()
+    try await spark.range(1).explain(true)
+    try await spark.range(1).explain("formatted")
+    await spark.stop()
+  }
+
   @Test
   func count() async throws {
     let spark = try await SparkSession.builder.getOrCreate()