diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index ee96d6d83f90e..4e464ddadbaa1 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -8556,6 +8556,12 @@ ], "sqlState" : "42KDF" }, + "ZIP_PLANS_NOT_MERGEABLE" : { + "message" : [ + "The two DataFrames in zip() cannot be merged because they do not derive from the same base plan through Project operations." + ], + "sqlState" : "42K03" + }, "_LEGACY_ERROR_TEMP_0001" : { "message" : [ "Invalid InsertIntoContext." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala index 0f1fe314c3500..1e2d79ea6b743 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -819,6 +819,22 @@ abstract class Dataset[T] extends Serializable { */ def crossJoin(right: Dataset[_]): DataFrame + /** + * Combines the columns of this DataFrame with another DataFrame that derives from the same + * base plan through Project operations. The analyzer rewrites the resulting Zip node into a + * single Project over the shared base plan. + * + * @param other + * Another DataFrame that shares the same base plan. + * @return + * A new DataFrame with columns from both sides. + * @throws AnalysisException + * if the two DataFrames do not derive from the same base plan. + * @group untypedrel + * @since 4.1.0 + */ + def zip(other: Dataset[_]): DataFrame + /** * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3b4d725840935..333e1e88c3d0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -498,6 +498,7 @@ class Analyzer( ResolveBinaryArithmetic :: new ResolveIdentifierClause(earlyBatches) :: ResolveUnion :: + ResolveZip :: FlattenSequentialStreamingUnion :: ValidateSequentialStreamingUnion :: ResolveRowLevelCommandAssignments :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 17fac640d4832..ab6fbd4f1ad14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -637,6 +637,19 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString messageParameters = Map.empty) } + case z: Zip => + def stripProjects(plan: LogicalPlan): LogicalPlan = plan match { + case Project(_, child) => stripProjects(child) + case other => other + } + val leftBase = stripProjects(z.left) + val rightBase = stripProjects(z.right) + if (!leftBase.sameResult(rightBase)) { + z.failAnalysis( + errorClass = "ZIP_PLANS_NOT_MERGEABLE", + messageParameters = Map.empty) + } + case a: Aggregate => a.groupingExpressions.foreach( expression => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala new file mode 100644 index 0000000000000..171269f6a3c42 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveZip.scala @@ -0,0 
+1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, NamedExpression,
+  PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Zip}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.ZIP
+
+/**
+ * Resolves a [[Zip]] node by rewriting it into a single [[Project]] over the shared base plan.
+ *
+ * Both children of Zip must derive from the same base plan through chains of scalar Project
+ * nodes (1:1 row mapping). `Project.resolved` already rejects Generator, AggregateExpression,
+ * and WindowExpression. This rule additionally rejects non-scalar Python UDFs (e.g.
+ * GROUPED_MAP), which are not caught by `Project.resolved`.
+ *
+ * This rule:
+ * 1. Waits for both children to be resolved
+ * 2. Collapses the chain of Project nodes on each side, inlining aliases, to find the base plan
+ * 3. Verifies the base plans produce the same result (via `sameResult`)
+ * 4. Verifies neither side contains a non-scalar Python UDF
+ * 5. Remaps the right side's attribute references to the left base plan's output
+ * 6. Produces a single Project that combines both sides' expressions
+ *
+ * If the base plans do not match, the Zip node stays unresolved and CheckAnalysis reports a
+ * ZIP_PLANS_NOT_MERGEABLE error. If a non-scalar Python UDF is present, the Zip also stays
+ * unresolved, and analysis fails on the generic unresolved-operator check instead.
+ */
+object ResolveZip extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(ZIP), ruleId) {
+    case z: Zip if z.childrenResolved =>
+      val (leftExprs, leftBase) = extractProjectAndBase(z.left)
+      val (rightExprs, rightBase) = extractProjectAndBase(z.right)
+      if (leftBase.sameResult(rightBase) && allScalar(leftExprs ++ rightExprs)) {
+        // Build an attribute mapping from rightBase output to leftBase output (by position).
+        val attrMapping = AttributeMap(rightBase.output.zip(leftBase.output))
+        // Remap right expressions to reference leftBase's attributes.
+        val remappedRightExprs = rightExprs.map { expr =>
+          expr.transform {
+            case a: Attribute => attrMapping.getOrElse(a, a)
+          }.asInstanceOf[NamedExpression]
+        }
+        Project(leftExprs ++ remappedRightExprs, leftBase)
+      } else {
+        z
+      }
+  }
+
+  /**
+   * Collapses the chain of Project nodes on top of a base plan, returning the projection
+   * expressions rewritten to reference the base plan's output directly, together with the
+   * base plan itself. A bare (non-Project) plan yields its own output attributes.
+   *
+   * Attributes produced by an inner Project are inlined: a top-level attribute reference is
+   * replaced by its defining Alias (preserving the output name and exprId), while references
+   * nested inside other expressions are replaced by the aliased child expression.
+   *
+   * NOTE(review): as in CollapseProject, inlining can duplicate an aliased expression that
+   * is referenced more than once; for a non-deterministic expression this would change
+   * semantics -- TODO confirm whether collapsing should bail out in that case.
+   */
+  private def extractProjectAndBase(
+      plan: LogicalPlan): (Seq[NamedExpression], LogicalPlan) = plan match {
+    case Project(projectList, child) =>
+      val (childExprs, base) = extractProjectAndBase(child)
+      // Maps each attribute produced by the inner projection to the Alias defining it.
+      val aliasMap = AttributeMap(childExprs.collect { case a: Alias => (a.toAttribute, a) })
+      val inlined = projectList.map {
+        case attr: Attribute => aliasMap.getOrElse(attr, attr)
+        case other =>
+          other.transform {
+            case attr: Attribute => aliasMap.get(attr).map(_.child).getOrElse(attr)
+          }.asInstanceOf[NamedExpression]
+      }
+      (inlined, base)
+    case other => (other.output, other)
+  }
+
+  /**
+   * Returns true if all expressions are scalar (1:1 row mapping).
+   * `Project.resolved` already rejects Generator, AggregateExpression, and WindowExpression.
+   * This additionally rejects non-scalar Python UDFs (e.g. GROUPED_MAP) that can break
+   * the 1:1 row mapping.
+   */
+  private def allScalar(exprs: Seq[NamedExpression]): Boolean = {
+    !exprs.exists(_.exists {
+      case udf: PythonUDF => !PythonUDF.isScalarPythonUDF(udf)
+      case _ => false
+    })
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 8e9f264698caf..f27cc4a9ead51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -829,6 +829,29 @@ case class Join(
       newLeft: LogicalPlan, newRight: LogicalPlan): Join = copy(left = newLeft, right = newRight)
 }
 
+/**
+ * A logical plan that combines the columns of two DataFrames that derive from the same
+ * base plan through chains of Project nodes. This node is always unresolved and must be
+ * rewritten by [[ResolveZip]] into a single Project over the shared base plan during
+ * analysis. If the two children do not share the same base plan (after stripping Project
+ * nodes), analysis will fail with an error.
+ */
+case class Zip(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def maxRows: Option[Long] = left.maxRows
+
+  override def maxRowsPerPartition: Option[Long] = left.maxRowsPerPartition
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(ZIP)
+
+  // Always unresolved -- must be rewritten by ResolveZip during analysis.
+  override lazy val resolved: Boolean = false
+
+  override protected def withNewChildrenInternal(
+      newLeft: LogicalPlan, newRight: LogicalPlan): Zip = copy(left = newLeft, right = newRight)
+}
+
 /**
  * Insert query result into a directory.
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 4ed918328a16b..b9873b390aa4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -104,6 +104,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.ResolveTableSpec" :: "org.apache.spark.sql.catalyst.analysis.ResolveTimeZone" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnion" :: + "org.apache.spark.sql.catalyst.analysis.ResolveZip" :: "org.apache.spark.sql.catalyst.analysis.ResolveUnresolvedHaving" :: "org.apache.spark.sql.catalyst.analysis.ResolveUpdateEventTimeWatermarkColumn" :: "org.apache.spark.sql.catalyst.analysis.ResolveWindowTime" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 1e22c1ce86539..b3d96da1cb52a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -176,6 +176,7 @@ object TreePattern extends Enumeration { val TRANSPOSE: Value = Value val UNION: Value = Value val UNPIVOT: Value = Value + val ZIP: Value = Value val UPDATE_EVENT_TIME_WATERMARK_COLUMN: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala new file mode 100644 index 0000000000000..5e1235cfec231 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveZipSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ResolveZipSuite extends AnalysisTest { + + private val base = LocalRelation($"a".int, $"b".int, $"c".int) + + object Resolve extends RuleExecutor[LogicalPlan] { + override val batches: Seq[Batch] = Seq( + Batch("ResolveZip", Once, ResolveZip)) + } + + test("resolve Zip: both sides have Project over same base") { + val left = Project(Seq(base.output(0)), base) + val right = Project(Seq(base.output(1)), base) + val zip = Zip(left, right) + + val resolved = Resolve.execute(zip) + val expected = Project(Seq(base.output(0), base.output(1)), base) + comparePlans(resolved, expected) + } + + test("resolve Zip: left is bare plan, right has Project") { + val right = Project(Seq(base.output(0)), base) + val zip = Zip(base, right) + + val resolved = Resolve.execute(zip) + val expected = Project(base.output ++ Seq(base.output(0)), base) + comparePlans(resolved, expected) + } + + test("resolve Zip: both sides are bare same plan") { + val zip 
= Zip(base, base) + + val resolved = Resolve.execute(zip) + val expected = Project(base.output ++ base.output, base) + comparePlans(resolved, expected) + } + + test("resolve Zip: both sides have expressions over same base") { + val left = base.select(($"a" + 1).as("a_plus_1")) + val right = base.select(($"b" * 2).as("b_times_2")) + val zip = Zip(left.analyze, right.analyze) + + val resolved = Resolve.execute(zip) + assert(!resolved.isInstanceOf[Zip], "Zip should have been resolved to a Project") + assert(resolved.isInstanceOf[Project]) + assert(resolved.output.length == 2) + assert(resolved.output(0).name == "a_plus_1") + assert(resolved.output(1).name == "b_times_2") + } + + test("resolve Zip: different base plans - Zip remains unresolved") { + val base2 = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int) + val left = Project(Seq(base.output(0)), base) + val right = Project(Seq(base2.output(0)), base2) + val zip = Zip(left, right) + + val resolved = Resolve.execute(zip) + // ResolveZip cannot merge, so Zip stays + assert(resolved.isInstanceOf[Zip]) + } + + test("resolve Zip: skipped when children are unresolved") { + val unresolvedChild = Project( + Seq(UnresolvedAttribute("a")), + UnresolvedRelation(Seq("t"))) + val zip = Zip(unresolvedChild, unresolvedChild) + + val result = Resolve.execute(zip) + // Zip should remain unchanged because children are not resolved + assert(result.isInstanceOf[Zip]) + } + + test("CheckAnalysis: different base plans throws ZIP_PLANS_NOT_MERGEABLE") { + val base2 = LocalRelation($"x".int, $"y".int, $"z".int, $"w".int) + val left = Project(Seq(base.output(0)), base) + val right = Project(Seq(base2.output(0)), base2) + val zip = Zip(left, right) + + assertAnalysisErrorCondition( + zip, + expectedErrorCondition = "ZIP_PLANS_NOT_MERGEABLE", + expectedMessageParameters = Map.empty + ) + } +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala index e9595dc64e9f0..841f6c7b94cac 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/Dataset.scala @@ -345,6 +345,11 @@ class Dataset[T] private[sql] ( builder.setJoinType(proto.Join.JoinType.JOIN_TYPE_CROSS) } + /** @inheritdoc */ + def zip(other: sql.Dataset[_]): DataFrame = { + throw new UnsupportedOperationException("zip is not supported in Spark Connect") + } + /** @inheritdoc */ def joinWith[U](other: sql.Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { val joinTypeValue = toJoinType(joinType, skipSemiAnti = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala index 2070873f96579..b4010f78eb527 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/Dataset.scala @@ -707,6 +707,11 @@ class Dataset[T] private[sql]( Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) } + /** @inheritdoc */ + def zip(other: sql.Dataset[_]): DataFrame = withPlan { + Zip(logicalPlan, other.logicalPlan) + } + /** @inheritdoc */ def joinWith[U](other: sql.Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameZipSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameZipSuite.scala new file mode 100644 index 0000000000000..bf5b2fdcf1eb0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameZipSuite.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.test.SharedSparkSession + +class DataFrameZipSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("zip: select different columns from the same DataFrame") { + val df = Seq((1, 2, 3), (4, 5, 6), (7, 8, 9)).toDF("a", "b", "c") + val left = df.select("a") + val right = df.select("b") + + checkAnswer( + left.zip(right), + Row(1, 2) :: Row(4, 5) :: Row(7, 8) :: Nil) + } + + test("zip: select with expressions over the same DataFrame") { + val df = Seq((1, 10), (2, 20), (3, 30)).toDF("a", "b") + val left = df.select(($"a" + 1).as("a_plus_1")) + val right = df.select(($"b" * 2).as("b_times_2")) + + checkAnswer( + left.zip(right), + Row(2, 20) :: Row(3, 40) :: Row(4, 60) :: Nil) + } + + test("zip: one side selects all columns") { + val df = Seq((1, 2), (3, 4)).toDF("a", "b") + val right = df.select(($"a" + $"b").as("sum")) + + checkAnswer( + df.zip(right), + Row(1, 2, 3) :: Row(3, 4, 7) :: Nil) + } + + test("zip: resolved plan is a Project") { + val df = Seq((1, 2)).toDF("a", "b") + val left = df.select("a") + val right = df.select("b") + val zipped = left.zip(right) + + assert(zipped.queryExecution.analyzed.isInstanceOf[Project]) + } 
+ + test("zip: different base plans throws AnalysisException") { + val df1 = Seq((1, 2)).toDF("a", "b") + val df2 = Seq((3, 4, 5)).toDF("x", "y", "z") + + checkError( + exception = intercept[AnalysisException] { + df1.select("a").zip(df2.select("x")).queryExecution.assertAnalyzed() + }, + condition = "ZIP_PLANS_NOT_MERGEABLE" + ) + } + + test("zip: different base plans from spark.range throws AnalysisException") { + val df1 = spark.range(10).toDF("id1") + val df2 = spark.range(20).toDF("id2") + + checkError( + exception = intercept[AnalysisException] { + df1.zip(df2).queryExecution.assertAnalyzed() + }, + condition = "ZIP_PLANS_NOT_MERGEABLE" + ) + } +}