From eb91d661119e50fa1497de866b51e7d6e1aae479 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Wed, 23 Apr 2025 17:26:21 +0900
Subject: [PATCH 1/2] [SPARK-51875] Support `repartition(ByExpression)?` and `coalesce`

---
 Sources/SparkConnect/DataFrame.swift          | 61 ++++++++++++++++++-
 Sources/SparkConnect/SparkConnectClient.swift | 32 ++++
 Sources/SparkConnect/TypeAliases.swift        |  2 +
 Tests/SparkConnectTests/DataFrameTests.swift  | 47 ++++
 4 files changed, 140 insertions(+), 2 deletions(-)

diff --git a/Sources/SparkConnect/DataFrame.swift b/Sources/SparkConnect/DataFrame.swift
index 31655b2..83dbba1 100644
--- a/Sources/SparkConnect/DataFrame.swift
+++ b/Sources/SparkConnect/DataFrame.swift
@@ -538,7 +538,7 @@ public actor DataFrame: Sendable {
   ///   - right: Right side of the join operation.
   ///   - usingColumn: Name of the column to join on. This column must exist on both sides.
   ///   - joinType: Type of join to perform. Default `inner`.
-  /// - Returns: <#description#>
+  /// - Returns: A `DataFrame`.
   public func join(_ right: DataFrame, _ usingColumn: String, _ joinType: String = "inner") async -> DataFrame {
     await join(right, [usingColumn], joinType)
   }
@@ -588,7 +588,7 @@ public actor DataFrame: Sendable {
 
   /// Explicit cartesian join with another `DataFrame`.
   /// - Parameter right: Right side of the join operation.
-  /// - Returns: Cartesian joins are very expensive without an extra filter that can be pushed down.
+  /// - Returns: A `DataFrame`.
   public func crossJoin(_ right: DataFrame) async -> DataFrame {
     let rightPlan = await (right.getPlan() as! Plan).root
     let plan = SparkConnectClient.getJoin(self.plan.root, rightPlan, JoinType.cross)
@@ -676,6 +676,63 @@ public actor DataFrame: Sendable {
     return DataFrame(spark: self.spark, plan: plan)
   }
 
+  private func buildRepartition(numPartitions: Int32, shuffle: Bool) -> DataFrame {
+    let plan = SparkConnectClient.getRepartition(self.plan.root, numPartitions, shuffle)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  private func buildRepartitionByExpression(numPartitions: Int32?, partitionExprs: [String]) -> DataFrame {
+    let plan = SparkConnectClient.getRepartitionByExpression(self.plan.root, partitionExprs, numPartitions)
+    return DataFrame(spark: self.spark, plan: plan)
+  }
+
+  /// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions.
+  /// - Parameter numPartitions: The number of partitions.
+  /// - Returns: A `DataFrame`.
+  public func repartition(_ numPartitions: Int32) -> DataFrame {
+    return buildRepartition(numPartitions: numPartitions, shuffle: true)
+  }
+
+  /// Returns a new ``DataFrame`` partitioned by the given partitioning expressions, using
+  /// `spark.sql.shuffle.partitions` as the number of partitions. The resulting ``DataFrame``
+  /// is hash partitioned.
+  /// - Parameter partitionExprs: The partition expression strings.
+  /// - Returns: A `DataFrame`.
+  public func repartition(_ partitionExprs: String...) -> DataFrame {
+    return buildRepartitionByExpression(numPartitions: nil, partitionExprs: partitionExprs)
+  }
+
+  /// Returns a new ``DataFrame`` partitioned by the given partitioning expressions into
+  /// exactly `numPartitions` partitions. The resulting ``DataFrame`` is hash partitioned.
+  ///
+  /// - Parameters:
+  ///   - numPartitions: The number of partitions.
+  ///   - partitionExprs: The partition expression strings.
+  /// - Returns: A `DataFrame`.
+  public func repartition(_ numPartitions: Int32, _ partitionExprs: String...) -> DataFrame {
+    return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs)
+  }
+
+  /// Returns a new ``DataFrame`` hash partitioned by the given partitioning expressions, using
+  /// `numPartitions` if given and `spark.sql.shuffle.partitions` otherwise.
+  /// - Parameter numPartitions: The optional number of partitions.
+  /// - Parameter partitionExprs: The partition expression strings.
+  /// - Returns: A `DataFrame`.
+  public func repartitionByExpression(_ numPartitions: Int32?, _ partitionExprs: String...) -> DataFrame {
+    return buildRepartitionByExpression(numPartitions: numPartitions, partitionExprs: partitionExprs)
+  }
+
+  /// Returns a new ``DataFrame`` that has exactly `numPartitions` partitions when fewer partitions
+  /// are requested. If a larger number of partitions is requested, it stays at the current
+  /// number of partitions. Similar to `coalesce` defined on an `RDD`, this operation results in a
+  /// narrow dependency; e.g., going from 1000 partitions to 100 partitions causes no shuffle, and
+  /// each of the 100 new partitions claims 10 of the current partitions.
+  /// - Parameter numPartitions: The number of partitions.
+  /// - Returns: A `DataFrame`.
+  public func coalesce(_ numPartitions: Int32) -> DataFrame {
+    return buildRepartition(numPartitions: numPartitions, shuffle: false)
+  }
+
   /// Returns a ``DataFrameWriter`` that can be used to write non-streaming data.
   public var write: DataFrameWriter {
     get {
diff --git a/Sources/SparkConnect/SparkConnectClient.swift b/Sources/SparkConnect/SparkConnectClient.swift
index 28f7eba..38d7df2 100644
--- a/Sources/SparkConnect/SparkConnectClient.swift
+++ b/Sources/SparkConnect/SparkConnectClient.swift
@@ -628,6 +628,38 @@ public actor SparkConnectClient {
     })
   }
 
+  static func getRepartition(_ child: Relation, _ numPartitions: Int32, _ shuffle: Bool = false) -> Plan {
+    var repartition = Repartition()
+    repartition.input = child
+    repartition.numPartitions = numPartitions
+    repartition.shuffle = shuffle
+    var relation = Relation()
+    relation.repartition = repartition
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
+  static func getRepartitionByExpression(
+    _ child: Relation, _ partitionExprs: [String], _ numPartitions: Int32? = nil
+  ) -> Plan {
+    var repartitionByExpression = RepartitionByExpression()
+    repartitionByExpression.input = child
+    repartitionByExpression.partitionExprs = partitionExprs.map {
+      var expr = Spark_Connect_Expression()
+      expr.expressionString = $0.toExpressionString
+      return expr
+    }
+    if let numPartitions {
+      repartitionByExpression.numPartitions = numPartitions
+    }
+    var relation = Relation()
+    relation.repartitionByExpression = repartitionByExpression
+    var plan = Plan()
+    plan.opType = .root(relation)
+    return plan
+  }
+
   private enum URIParams {
     static let PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
     static let PARAM_SESSION_ID = "session_id"
diff --git a/Sources/SparkConnect/TypeAliases.swift b/Sources/SparkConnect/TypeAliases.swift
index 1107c52..2858de2 100644
--- a/Sources/SparkConnect/TypeAliases.swift
+++ b/Sources/SparkConnect/TypeAliases.swift
@@ -42,6 +42,8 @@ typealias Project = Spark_Connect_Project
 typealias Range = Spark_Connect_Range
 typealias Read = Spark_Connect_Read
 typealias Relation = Spark_Connect_Relation
+typealias Repartition = Spark_Connect_Repartition
+typealias RepartitionByExpression = Spark_Connect_RepartitionByExpression
 typealias Sample = Spark_Connect_Sample
 typealias SaveMode = Spark_Connect_WriteOperation.SaveMode
 typealias SetOperation = Spark_Connect_SetOperation
diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index ee220c3..a98a936 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -17,6 +17,7 @@
 // under the License.
 //
 
+import Foundation
 import Testing
 
 import SparkConnect
@@ -530,6 +531,52 @@ struct DataFrameTests {
     #expect(try await df3.unionByName(df3).count() == 4)
     await spark.stop()
   }
+
+  @Test
+  func repartition() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let df = try await spark.range(2025)
+    for n in [1, 3, 5] as [Int32] {
+      try await df.repartition(n).write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+    }
+    try await spark.range(1).repartition(10).write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    await spark.stop()
+  }
+
+  @Test
+  func repartitionByExpression() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let df = try await spark.range(2025)
+    for n in [1, 3, 5] as [Int32] {
+      try await df.repartition(n, "id").write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+      try await df.repartitionByExpression(n, "id").write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+    }
+    try await spark.range(1).repartition(10, "id").write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    try await spark.range(1).repartition("id").write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    await spark.stop()
+  }
+
+  @Test
+  func coalesce() async throws {
+    let spark = try await SparkSession.builder.getOrCreate()
+    let tmpDir = "/tmp/" + UUID().uuidString
+    let df = try await spark.range(2025)
+    for n in [1, 3, 5] as [Int32] {
+      try await df.coalesce(n).write.mode("overwrite").orc(tmpDir)
+      #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
+    }
+    try await spark.range(1).coalesce(10).write.mode("overwrite").orc(tmpDir)
+    #expect(try await spark.read.orc(tmpDir).inputFiles().count < 10)
+    await spark.stop()
+  }
   #endif
 
   @Test

From e27ce9b959956fc43a479121ec717e8c3a46e976 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Wed, 23 Apr 2025 17:36:21 +0900
Subject: [PATCH 2/2] Reduce the `coalesce` test partition counts for macOS

---
 Tests/SparkConnectTests/DataFrameTests.swift | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Tests/SparkConnectTests/DataFrameTests.swift b/Tests/SparkConnectTests/DataFrameTests.swift
index a98a936..5772120 100644
--- a/Tests/SparkConnectTests/DataFrameTests.swift
+++ b/Tests/SparkConnectTests/DataFrameTests.swift
@@ -569,7 +569,7 @@ struct DataFrameTests {
     let spark = try await SparkSession.builder.getOrCreate()
     let tmpDir = "/tmp/" + UUID().uuidString
     let df = try await spark.range(2025)
-    for n in [1, 3, 5] as [Int32] {
+    for n in [1, 2, 3] as [Int32] {
      try await df.coalesce(n).write.mode("overwrite").orc(tmpDir)
       #expect(try await spark.read.orc(tmpDir).inputFiles().count == n)
     }
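
For reference, a minimal usage sketch of the APIs introduced by this patch. The `repartitionDemo`
wrapper is illustrative only; like the tests above, it assumes a reachable Spark Connect server and
an async calling context:

    import SparkConnect

    func repartitionDemo() async throws {
      let spark = try await SparkSession.builder.getOrCreate()
      let df = try await spark.range(2025)

      // Full shuffle into exactly 5 hash partitions.
      let byCount = await df.repartition(5)

      // Hash-partition by the `id` expression, using spark.sql.shuffle.partitions
      // as the number of partitions.
      let byExpr = await df.repartition("id")

      // Hash-partition by `id` into exactly 3 partitions.
      let byBoth = await df.repartitionByExpression(3, "id")

      // Narrow-dependency reduction to at most 2 partitions, without a shuffle.
      let merged = await byCount.coalesce(2)

      // Repartitioning changes only the physical layout; the rows are unchanged.
      for frame in [byExpr, byBoth, merged] {
        print(try await frame.count())  // 2025 each time
      }

      await spark.stop()
    }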