From 8cef24b5881630f5296730b221ab2b0c902c5fbf Mon Sep 17 00:00:00 2001
From: Alex Barreto
Date: Wed, 2 Dec 2020 12:48:29 -0600
Subject: [PATCH] [SPARK-33071][SPARK-33536][SQL] Avoid changing dataset_id of LogicalPlan in join() to not break DetectAmbiguousSelfJoin

### What changes were proposed in this pull request?

Currently, `join()` uses `withPlan(logicalPlan)` for convenience when calling some Dataset
functions. However, this makes the `dataset_id` inconsistent between the `logicalPlan` and the
original `Dataset`, because `withPlan(logicalPlan)` creates a new Dataset with a new id and
resets the `dataset_id` of the `logicalPlan` to that new id. As a result, it breaks the rule
`DetectAmbiguousSelfJoin`.

In this PR, we propose to drop the usage of `withPlan` and use the `logicalPlan` directly, so
its `dataset_id` does not change.

Besides, this PR also removes the related metadata (`DATASET_ID_KEY`, `COL_POS_KEY`) when an
`Alias` constructs its own metadata, because the `Alias` is no longer a column reference after
being converted to an `Attribute`. To achieve that, we add a new field, `deniedMetadataKeys`,
to indicate the metadata keys that need to be removed.

### Why are the changes needed?

For the query below, the current code returns a wrong result, while it should throw an
ambiguous self-join exception instead:

```scala
val emp1 = Seq[TestData](
  TestData(1, "sales"),
  TestData(2, "personnel"),
  TestData(3, "develop"),
  TestData(4, "IT")).toDS()
val emp2 = Seq[TestData](
  TestData(1, "sales"),
  TestData(2, "personnel"),
  TestData(3, "develop")).toDS()
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))

emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer")
  .select(emp1.col("*"), emp3.col("key").as("e2")).show()

// wrong result
+---+---------+---+
|key|    value| e2|
+---+---------+---+
|  1|    sales|  1|
|  2|personnel|  2|
|  3|  develop|  3|
|  4|       IT|  4|
+---+---------+---+
```

This PR fixes the wrong behaviour.

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, users hit the exception instead of the wrong result.

### How was this patch tested?

Added a new unit test.

Closes #30488 from Ngone51/fix-self-join.
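A minimal sketch of the user-visible behaviour after this change, reusing `emp1`/`emp3` from
the example above. It assumes a ScalaTest context, the default
`spark.sql.analyzer.failAmbiguousSelfJoin=true`, and that the analysis error message contains
the word "ambiguous"; the exact wording is an assumption, not quoted from the patch.

```scala
import org.apache.spark.sql.AnalysisException

// The ambiguous self join now fails at analysis time instead of silently resolving
// emp3("key") to the wrong side of the join and returning rows.
val thrown = intercept[AnalysisException] {
  emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer")
    .select(emp1.col("*"), emp3.col("key").as("e2"))
    .show()
}
assert(thrown.getMessage.contains("ambiguous"))
```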
Authored-by: yi.wu Signed-off-by: Wenchen Fan --- .../integrationtest/DecommissionSuite.scala | 6 +-- .../k8s/integrationtest/DepsTestsSuite.scala | 2 +- .../k8s/integrationtest/KubernetesSuite.scala | 32 ++++++++++-- .../integrationtest/PythonTestsSuite.scala | 6 +-- .../k8s/integrationtest/RTestsSuite.scala | 2 +- .../SparkConfPropagateSuite.scala | 22 ++++----- .../catalyst/expressions/AliasHelper.scala | 3 +- .../expressions/complexTypeCreator.scala | 18 ++++--- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/datetimeExpressions.scala | 7 ++- .../expressions/intervalExpressions.scala | 14 +++--- .../expressions/namedExpressions.scala | 15 ++++-- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../expressions/ExpressionEvalHelper.scala | 49 ++++++------------- .../ExpressionEvalHelperSuite.scala | 25 +++++++++- .../IntervalExpressionsSuite.scala | 36 +++++++------- .../expressions/MathExpressionsSuite.scala | 5 +- .../optimizer/ConstantFoldingSuite.scala | 2 +- .../InferFiltersFromGenerateSuite.scala | 6 +-- .../scala/org/apache/spark/sql/Column.scala | 5 +- .../scala/org/apache/spark/sql/Dataset.scala | 39 +++++++++------ .../spark/sql/DataFrameSelfJoinSuite.scala | 29 +++++++++++ .../sql/SparkSessionExtensionSuite.scala | 7 +-- 23 files changed, 207 insertions(+), 127 deletions(-) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index 9d7db04bb7..92f6a32cd1 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -38,7 +38,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_DECOMISSIONING, mainClass = "", - expectedLogOnCompletion = Seq( + expectedDriverLogOnCompletion = Seq( "Finished waiting, stopping Spark", "Decommission executors", "Final accumulator value is: 100"), @@ -69,7 +69,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_DECOMISSIONING_CLEANUP, mainClass = "", - expectedLogOnCompletion = Seq( + expectedDriverLogOnCompletion = Seq( "Finished waiting, stopping Spark", "Decommission executors"), appArgs = Array.empty[String], @@ -104,7 +104,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_SCALE, mainClass = "", - expectedLogOnCompletion = Seq( + expectedDriverLogOnCompletion = Seq( "Finished waiting, stopping Spark", "Decommission executors"), appArgs = Array.empty[String], diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala index 8f6e9cd8af..760e9ba55d 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DepsTestsSuite.scala @@ -177,7 +177,7 @@ 
private[spark] trait DepsTestsSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = pySparkFiles, mainClass = "", - expectedLogOnCompletion = Seq( + expectedDriverLogOnCompletion = Seq( "Python runtime version check is: True", "Python environment version check is: True", "Python runtime version check for executor is: True"), diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index cc226b3419..193a02aad0 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -171,6 +171,7 @@ class KubernetesSuite extends SparkFunSuite appResource, SPARK_PI_MAIN_CLASS, Seq("Pi is roughly 3"), + Seq(), appArgs, driverPodChecker, executorPodChecker, @@ -192,6 +193,7 @@ class KubernetesSuite extends SparkFunSuite SPARK_DFS_READ_WRITE_TEST, Seq(s"Success! Local Word Count $wordCount and " + s"DFS Word Count $wordCount agree."), + Seq(), appArgs, driverPodChecker, executorPodChecker, @@ -212,6 +214,7 @@ class KubernetesSuite extends SparkFunSuite appResource, SPARK_REMOTE_MAIN_CLASS, Seq(s"Mounting of ${appArgs.head} was true"), + Seq(), appArgs, driverPodChecker, executorPodChecker, @@ -261,7 +264,8 @@ class KubernetesSuite extends SparkFunSuite protected def runSparkApplicationAndVerifyCompletion( appResource: String, mainClass: String, - expectedLogOnCompletion: Seq[String], + expectedDriverLogOnCompletion: Seq[String], + expectedExecutorLogOnCompletion: Seq[String] = Seq(), appArgs: Array[String], driverPodChecker: Pod => Unit, executorPodChecker: Pod => Unit, @@ -374,7 +378,6 @@ class KubernetesSuite extends SparkFunSuite .list() .getItems .get(0) - driverPodChecker(driverPod) // If we're testing decommissioning we an executors, but we should have an executor @@ -383,14 +386,35 @@ class KubernetesSuite extends SparkFunSuite execPods.values.nonEmpty should be (true) } execPods.values.foreach(executorPodChecker(_)) + + val execPod: Option[Pod] = if (expectedExecutorLogOnCompletion.nonEmpty) { + Some(kubernetesTestComponents.kubernetesClient + .pods() + .withLabel("spark-app-locator", appLocator) + .withLabel("spark-role", "executor") + .list() + .getItems + .get(0)) + } else { + None + } + Eventually.eventually(patienceTimeout, patienceInterval) { - expectedLogOnCompletion.foreach { e => + expectedDriverLogOnCompletion.foreach { e => assert(kubernetesTestComponents.kubernetesClient .pods() .withName(driverPod.getMetadata.getName) .getLog .contains(e), - s"The application did not complete, did not find str ${e}") + s"The application did not complete, driver log did not contain str ${e}") + } + expectedExecutorLogOnCompletion.foreach { e => + assert(kubernetesTestComponents.kubernetesClient + .pods() + .withName(execPod.get.getMetadata.getName) + .getLog + .contains(e), + s"The application did not complete, executor log did not contain str ${e}") } } execWatcher.close() diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala 
index bad6f1c102..457a766cae 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/PythonTestsSuite.scala @@ -27,7 +27,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_PI, mainClass = "", - expectedLogOnCompletion = Seq("Pi is roughly 3"), + expectedDriverLogOnCompletion = Seq("Pi is roughly 3"), appArgs = Array("5"), driverPodChecker = doBasicDriverPyPodCheck, executorPodChecker = doBasicExecutorPyPodCheck, @@ -41,7 +41,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_FILES, mainClass = "", - expectedLogOnCompletion = Seq( + expectedDriverLogOnCompletion = Seq( "Python runtime version check is: True", "Python environment version check is: True", "Python runtime version check for executor is: True"), @@ -61,7 +61,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = PYSPARK_MEMORY_CHECK, mainClass = "", - expectedLogOnCompletion = Seq( + expectedDriverLogOnCompletion = Seq( "PySpark Worker Memory Check is: True"), appArgs = Array(s"$additionalMemoryInBytes"), driverPodChecker = doDriverMemoryCheck, diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala index b7c8886a15..a22066c180 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/RTestsSuite.scala @@ -26,7 +26,7 @@ private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite => runSparkApplicationAndVerifyCompletion( appResource = SPARK_R_DATAFRAME_TEST, mainClass = "", - expectedLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"), + expectedDriverLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"), appArgs = Array.empty[String], driverPodChecker = doBasicDriverRPodCheck, executorPodChecker = doBasicExecutorRPodCheck, diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkConfPropagateSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkConfPropagateSuite.scala index 6d15201d19..5d3b426598 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkConfPropagateSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/SparkConfPropagateSuite.scala @@ -16,14 +16,11 @@ */ package org.apache.spark.deploy.k8s.integrationtest -import java.io.{BufferedWriter, File, FileWriter} +import java.io.File import java.net.URL +import java.nio.file.Files -import scala.io.{BufferedSource, Source} - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.internal.config +import scala.io.Source private[spark] trait SparkConfPropagateSuite { k8sSuite: KubernetesSuite => import KubernetesSuite.{k8sTestTag, 
SPARK_PI_MAIN_CLASS} @@ -38,18 +35,21 @@ private[spark] trait SparkConfPropagateSuite { k8sSuite: KubernetesSuite => val logConfFilePath = s"${sparkHomeDir.toFile}/conf/log4j.properties" try { - val writer = new BufferedWriter(new FileWriter(logConfFilePath)) - writer.write(content) - writer.close() + Files.write(new File(logConfFilePath).toPath, content.getBytes) sparkAppConf.set("spark.driver.extraJavaOptions", "-Dlog4j.debug") + sparkAppConf.set("spark.executor.extraJavaOptions", "-Dlog4j.debug") + + val log4jExpectedLog = + s"log4j: Reading configuration from URL file:/opt/spark/conf/log4j.properties" runSparkApplicationAndVerifyCompletion( appResource = containerLocalSparkDistroExamplesJar, mainClass = SPARK_PI_MAIN_CLASS, - expectedLogOnCompletion = (Seq("DEBUG", - s"log4j: Reading configuration from URL file:/opt/spark/conf/log4j.properties", + expectedDriverLogOnCompletion = (Seq("DEBUG", + log4jExpectedLog, "Pi is roughly 3")), + expectedExecutorLogOnCompletion = Seq(log4jExpectedLog), appArgs = Array.empty[String], driverPodChecker = doBasicDriverPodCheck, executorPodChecker = doBasicExecutorPodCheck, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index ec47875754..c61eb68db5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -89,7 +89,8 @@ trait AliasHelper { a.copy(child = trimAliases(a.child))( exprId = a.exprId, qualifier = a.qualifier, - explicitMetadata = Some(a.metadata)) + explicitMetadata = Some(a.metadata), + deniedMetadataKeys = a.deniedMetadataKeys) case a: MultiAlias => a.copy(child = trimAliases(a.child)) case other => trimAliases(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index f0f92e2d93..cb59fbda2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -31,10 +31,16 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * Trait to indicate the expression doesn't have any side effects. This can be used - * to indicate its ok to optimize it out under certain circumstances. + * Trait to indicate the expression does not throw an exception by itself when they are evaluated. + * For example, UDFs, [[AssertTrue]], etc can throw an exception when they are executed. + * In such case, it is necessary to call [[Expression.eval]], and the optimization rule should + * not ignore it. + * + * This trait can be used in an optimization rule such as + * [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]] to fold the expressions that + * do not need to execute, for example, `size(array(c0, c1, c2))`. */ -trait NoSideEffect +trait NoThrow /** * Returns an Array containing the evaluation of all children expressions. 
@@ -48,7 +54,7 @@ trait NoSideEffect """, since = "1.1.0") case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) - extends Expression with NoSideEffect { + extends Expression with NoThrow { def this(children: Seq[Expression]) = { this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) @@ -166,7 +172,7 @@ private [sql] object GenArrayData { """, since = "2.0.0") case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean) - extends Expression with NoSideEffect{ + extends Expression with NoThrow { def this(children: Seq[Expression]) = { this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE)) @@ -385,7 +391,7 @@ object CreateStruct { """, since = "1.5.0") // scalastyle:on line.size.limit -case class CreateNamedStruct(children: Seq[Expression]) extends Expression with NoSideEffect { +case class CreateNamedStruct(children: Seq[Expression]) extends Expression with NoThrow { lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 767650d022..ef247efbe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -394,7 +394,7 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes { val keyJavaType = CodeGenerator.javaType(keyType) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val keyNotFoundBranch = if (failOnError) { - s"""throw new NoSuchElementException("Key " + $eval2 + " does not exist.");""" + s"""throw new java.util.NoSuchElementException("Key " + $eval2 + " does not exist.");""" } else { s"${ev.isNull} = true;" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index bbf1e4657f..424887a13c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1789,8 +1789,11 @@ private case class GetTimestamp( """, group = "datetime_funcs", since = "3.0.0") -case class MakeDate(year: Expression, month: Expression, day: Expression, - failOnError: Boolean = SQLConf.get.ansiEnabled) +case class MakeDate( + year: Expression, + month: Expression, + day: Expression, + failOnError: Boolean = SQLConf.get.ansiEnabled) extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant { def this(year: Expression, month: Expression, day: Expression) = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 6219457bba..27067e17e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -166,13 +166,13 @@ case class MakeInterval( extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant { def this( - 
years: Expression, - months: Expression, - weeks: Expression, - days: Expression, - hours: Expression, - mins: Expression, - sec: Expression) = { + years: Expression, + months: Expression, + weeks: Expression, + days: Expression, + hours: Expression, + mins: Expression, + sec: Expression) = { this(years, months, weeks, days, hours, mins, sec, SQLConf.get.ansiEnabled) } def this( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 2abd9d7bb4..22aabd3c6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -143,11 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn * fully qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. * @param explicitMetadata Explicit metadata associated with this alias that overwrites child's. + * @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when + * inheriting the metadata from the child. */ case class Alias(child: Expression, name: String)( val exprId: ExprId = NamedExpression.newExprId, val qualifier: Seq[String] = Seq.empty, - val explicitMetadata: Option[Metadata] = None) + val explicitMetadata: Option[Metadata] = None, + val deniedMetadataKeys: Seq[String] = Seq.empty) extends UnaryExpression with NamedExpression { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) @@ -167,7 +170,11 @@ case class Alias(child: Expression, name: String)( override def metadata: Metadata = { explicitMetadata.getOrElse { child match { - case named: NamedExpression => named.metadata + case named: NamedExpression => + val builder = new MetadataBuilder().withMetadata(named.metadata) + deniedMetadataKeys.foreach(builder.remove) + builder.build() + case _ => Metadata.empty } } @@ -194,7 +201,7 @@ case class Alias(child: Expression, name: String)( override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix" override protected final def otherCopyArgs: Seq[AnyRef] = { - exprId :: qualifier :: explicitMetadata :: Nil + exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil } override def hashCode(): Int = { @@ -205,7 +212,7 @@ case class Alias(child: Expression, name: String)( override def equals(other: Any): Boolean = other match { case a: Alias => name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier && - explicitMetadata == a.explicitMetadata + explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4725f49340..1b1e2ad71e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -45,7 +45,7 @@ object ConstantFolding extends Rule[LogicalPlan] { private def hasNoSideEffect(e: Expression): Boolean = e match { case _: Attribute => true case _: Literal => true - case _: NoSideEffect => e.children.forall(hasNoSideEffect) + case _: NoThrow if e.deterministic => 
e.children.forall(hasNoSideEffect) case _ => false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 842c8f3243..70eb391ad6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils /** * A few helper functions for expression evaluation testing. Mixin this trait to use them. @@ -160,9 +159,14 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB expectedErrMsg: String): Unit = { def checkException(eval: => Unit, testMode: String): Unit = { + val modes = Seq(CodegenObjectFactoryMode.CODEGEN_ONLY, CodegenObjectFactoryMode.NO_CODEGEN) withClue(s"($testMode)") { val errMsg = intercept[T] { - eval + for (fallbackMode <- modes) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallbackMode.toString) { + eval + } + } }.getMessage if (errMsg == null) { if (expectedErrMsg != null) { @@ -192,22 +196,6 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB expression.eval(inputRow) } - protected def generateProject( - generator: => Projection, - expression: Expression): Projection = { - try { - generator - } catch { - case e: Throwable => - fail( - s""" - |Code generation of $expression failed: - |$e - |${Utils.exceptionString(e)} - """.stripMargin) - } - } - protected def checkEvaluationWithoutCodegen( expression: Expression, expected: Any, @@ -244,9 +232,7 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB protected def evaluateWithMutableProjection( expression: => Expression, inputRow: InternalRow = EmptyRow): Any = { - val plan = generateProject( - MutableProjection.create(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) + val plan = MutableProjection.create(Alias(expression, s"Optimized($expression)")() :: Nil) plan.initialize(0) plan(inputRow).get(0, expression.dataType) @@ -292,11 +278,9 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB // SPARK-16489 Explicitly doing code generation twice so code gen will fail if // some expression is reusing variable names across different instances. // This behavior is tested in ExpressionEvalHelperSuite. 
- val plan = generateProject( - UnsafeProjection.create( - Alias(expression, s"Optimized($expression)1")() :: - Alias(expression, s"Optimized($expression)2")() :: Nil), - expression) + val plan = UnsafeProjection.create( + Alias(expression, s"Optimized($expression)1")() :: + Alias(expression, s"Optimized($expression)2")() :: Nil) plan.initialize(0) plan(inputRow) @@ -319,16 +303,13 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB checkEvaluationWithMutableProjection(expression, expected) checkEvaluationWithOptimization(expression, expected) - var plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) + var plan: Projection = + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) plan.initialize(0) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected, expression)) - plan = generateProject( - GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) + plan = GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) plan.initialize(0) val ref = new BoundReference(0, expression.dataType, nullable = true) actual = GenerateSafeProjection.generate(ref :: Nil)(plan(inputRow)).get(0, expression.dataType) @@ -456,9 +437,7 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB } } - val plan = generateProject( - GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil), - expr) + val plan = GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil) val (codegen, codegenExc) = try { (Some(plan(inputRow).get(0, expr.dataType)), None) } catch { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 54ef9641be..3cc50da389 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.types.{DataType, IntegerType, MapType} */ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper { - test("SPARK-16489 checkEvaluation should fail if expression reuses variable names") { - val e = intercept[RuntimeException] { checkEvaluation(BadCodegenExpression(), 10) } + test("SPARK-16489: checkEvaluation should fail if expression reuses variable names") { + val e = intercept[Exception] { checkEvaluation(BadCodegenExpression(), 10) } assert(e.getMessage.contains("some_variable")) } @@ -43,6 +43,12 @@ class ExpressionEvalHelperSuite extends SparkFunSuite with ExpressionEvalHelper } assert(e.getMessage.contains("and exprNullable was")) } + + test("SPARK-33619: make sure checkExceptionInExpression work as expected") { + checkExceptionInExpression[Exception]( + BadCodegenAndEvalExpression(), + "Cannot determine simple type name \"NoSuchElementException\"") + } } /** @@ -76,3 +82,18 @@ case class MapIncorrectDataTypeExpression() extends LeafExpression with CodegenF // since values includes null, valueContainsNull must be true override def dataType: DataType = MapType(IntegerType, IntegerType, valueContainsNull = false) } + +case class BadCodegenAndEvalExpression() extends LeafExpression { + override def 
nullable: Boolean = false + override def eval(input: InternalRow): Any = + throw new Exception("Cannot determine simple type name \"NoSuchElementException\"") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // it should be java.util.NoSuchElementException in generated code. + ev.copy(code = + code""" + |int ${ev.value} = 10; + |throw new NoSuchElementException("compile failed!"); + """.stripMargin) + } + override def dataType: DataType = IntegerType +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 5c73a91de4..950637c958 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -217,15 +217,15 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("ANSI mode: make interval") { def check( - years: Int = 0, - months: Int = 0, - weeks: Int = 0, - days: Int = 0, - hours: Int = 0, - minutes: Int = 0, - seconds: Int = 0, - millis: Int = 0, - micros: Int = 0): Unit = { + years: Int = 0, + months: Int = 0, + weeks: Int = 0, + days: Int = 0, + hours: Int = 0, + minutes: Int = 0, + seconds: Int = 0, + millis: Int = 0, + micros: Int = 0): Unit = { val secFrac = DateTimeTestUtils.secFrac(seconds, millis, micros) val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks), Literal(days), Literal(hours), Literal(minutes), @@ -238,15 +238,15 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } def checkException( - years: Int = 0, - months: Int = 0, - weeks: Int = 0, - days: Int = 0, - hours: Int = 0, - minutes: Int = 0, - seconds: Int = 0, - millis: Int = 0, - micros: Int = 0): Unit = { + years: Int = 0, + months: Int = 0, + weeks: Int = 0, + days: Int = 0, + hours: Int = 0, + minutes: Int = 0, + seconds: Int = 0, + millis: Int = 0, + micros: Int = 0): Unit = { val secFrac = DateTimeTestUtils.secFrac(seconds, millis, micros) val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks), Literal(days), Literal(hours), Literal(minutes), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index b4096f21be..6d09e28362 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -138,9 +138,8 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { - val plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) + val plan = + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil) val actual = plan(inputRow).get(0, expression.dataType) if (!actual.asInstanceOf[Double].isNaN) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index fd9b58a7a0..ae644c1110 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -264,7 +264,7 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("SPARK-33544: Constant folding test with sideaffects") { + test("SPARK-33544: Constant folding test with side effects") { val originalQuery = testRelation .select('a) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala index c6fa1bd6e4..93a1d414ed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala @@ -90,13 +90,13 @@ class InferFiltersFromGenerateSuite extends PlanTest { Seq(Explode(_), PosExplode(_)).foreach { f => val createArrayExplode = f(CreateArray(Seq('c1))) - test("Don't infer filters from CreateArray " + createArrayExplode) { + test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) { val originalQuery = testRelation.generate(createArrayExplode).analyze val optimized = OptimizeInferAndConstantFold.execute(originalQuery) comparePlans(optimized, originalQuery) } val createMapExplode = f(CreateMap(Seq('c1, 'c2))) - test("Don't infer filters from CreateMap " + createMapExplode) { + test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) { val originalQuery = testRelation.generate(createMapExplode).analyze val optimized = OptimizeInferAndConstantFold.execute(originalQuery) comparePlans(optimized, originalQuery) @@ -105,7 +105,7 @@ class InferFiltersFromGenerateSuite extends PlanTest { Seq(Inline(_)).foreach { f => val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1))))) - test("Don't infer filters from CreateArray " + createArrayStructExplode) { + test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) { val originalQuery = testRelation.generate(createArrayStructExplode).analyze val optimized = OptimizeInferAndConstantFold.execute(originalQuery) comparePlans(optimized, originalQuery) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 95134d9111..86ba813402 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1164,7 +1164,10 @@ class Column(val expr: Expression) extends Logging { * @since 2.0.0 */ def name(alias: String): Column = withExpr { - Alias(normalizedExpr(), alias)() + // SPARK-33536: The Alias is no longer a column reference after converting to an attribute. + // These denied metadata keys are used to strip the column reference related metadata for + // the Alias. So it won't be caught as a column reference in DetectAmbiguousSelfJoin. 
+ Alias(expr, alias)(deniedMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2c38a65ac2..0716043bcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -231,7 +231,8 @@ class Dataset[T] private[sql]( case _ => queryExecution.analyzed } - if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { + if (sparkSession.sessionState.conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) && + plan.getTagValue(Dataset.DATASET_ID_TAG).isEmpty) { plan.setTagValue(Dataset.DATASET_ID_TAG, id) } plan @@ -259,15 +260,16 @@ class Dataset[T] private[sql]( private[sql] def resolve(colName: String): NamedExpression = { val resolver = sparkSession.sessionState.analyzer.resolver queryExecution.analyzed.resolveQuoted(colName, resolver) - .getOrElse { - val fields = schema.fieldNames - val extraMsg = if (fields.exists(resolver(_, colName))) { - s"; did you mean to quote the `$colName` column?" - } else "" - val fieldsStr = fields.mkString(", ") - val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}""" - throw new AnalysisException(errorMsg) - } + .getOrElse(throw resolveException(colName, schema.fieldNames)) + } + + private def resolveException(colName: String, fields: Array[String]): AnalysisException = { + val extraMsg = if (fields.exists(sparkSession.sessionState.analyzer.resolver(_, colName))) { + s"; did you mean to quote the `$colName` column?" + } else "" + val fieldsStr = fields.mkString(", ") + val errorMsg = s"""Cannot resolve column name "$colName" among (${fieldsStr})${extraMsg}""" + new AnalysisException(errorMsg) } private[sql] def numericColumns: Seq[Expression] = { @@ -1083,8 +1085,8 @@ class Dataset[T] private[sql]( } // If left/right have no output set intersection, return the plan. - val lanalyzed = withPlan(this.logicalPlan).queryExecution.analyzed - val ranalyzed = withPlan(right.logicalPlan).queryExecution.analyzed + val lanalyzed = this.queryExecution.analyzed + val ranalyzed = right.queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { return withPlan(plan) } @@ -1092,17 +1094,22 @@ class Dataset[T] private[sql]( // Otherwise, find the trivially true predicates and automatically resolves them to both sides. // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. 
+ val resolver = sparkSession.sessionState.analyzer.resolver val cond = plan.condition.map { _.transform { case catalyst.expressions.EqualTo(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => catalyst.expressions.EqualTo( - withPlan(plan.left).resolve(a.name), - withPlan(plan.right).resolve(b.name)) + plan.left.resolveQuoted(a.name, resolver) + .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)), + plan.right.resolveQuoted(b.name, resolver) + .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames))) case catalyst.expressions.EqualNullSafe(a: AttributeReference, b: AttributeReference) if a.sameRef(b) => catalyst.expressions.EqualNullSafe( - withPlan(plan.left).resolve(a.name), - withPlan(plan.right).resolve(b.name)) + plan.left.resolveQuoted(a.name, resolver) + .getOrElse(throw resolveException(a.name, plan.left.schema.fieldNames)), + plan.right.resolveQuoted(b.name, resolver) + .getOrElse(throw resolveException(b.name, plan.right.schema.fieldNames))) }} withPlan { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 3b3b54f75d..50846d9d12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.test.SQLTestData.TestData class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -219,4 +220,32 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { Seq((1, 2), (1, 2), (2, 4), (2, 4)).map(Row.fromTuple)) } } + + test("SPARK-33071/SPARK-33536: Avoid changing dataset_id of LogicalPlan in join() " + + "to not break DetectAmbiguousSelfJoin") { + val emp1 = Seq[TestData]( + TestData(1, "sales"), + TestData(2, "personnel"), + TestData(3, "develop"), + TestData(4, "IT")).toDS() + val emp2 = Seq[TestData]( + TestData(1, "sales"), + TestData(2, "personnel"), + TestData(3, "develop")).toDS() + val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*")) + assertAmbiguousSelfJoin(emp1.join(emp3, emp1.col("key") === emp3.col("key"), + "left_outer").select(emp1.col("*"), emp3.col("key").as("e2"))) + } + + test("df.show() should also not change dataset_id of LogicalPlan") { + val df = Seq[TestData]( + TestData(1, "sales"), + TestData(2, "personnel"), + TestData(3, "develop"), + TestData(4, "IT")).toDF() + val ds_id1 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG) + df.show(0) + val ds_id2 = df.logicalPlan.getTagValue(Dataset.DATASET_ID_TAG) + assert(ds_id1 === ds_id2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 12abd31b99..f02d2041dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -573,8 +573,9 @@ class ColumnarBoundReference(ordinal: Int, dataType: DataType, nullable: Boolean class ColumnarAlias(child: ColumnarExpression, name: String)( override val exprId: ExprId = NamedExpression.newExprId, override val qualifier: Seq[String] = Seq.empty, - override val 
explicitMetadata: Option[Metadata] = None) - extends Alias(child, name)(exprId, qualifier, explicitMetadata) + override val explicitMetadata: Option[Metadata] = None, + override val deniedMetadataKeys: Seq[String] = Seq.empty) + extends Alias(child, name)(exprId, qualifier, explicitMetadata, deniedMetadataKeys) with ColumnarExpression { override def columnarEval(batch: ColumnarBatch): Any = child.columnarEval(batch) @@ -711,7 +712,7 @@ case class PreRuleReplaceAddWithBrokenVersion() extends Rule[SparkPlan] { def replaceWithColumnarExpression(exp: Expression): ColumnarExpression = exp match { case a: Alias => new ColumnarAlias(replaceWithColumnarExpression(a.child), - a.name)(a.exprId, a.qualifier, a.explicitMetadata) + a.name)(a.exprId, a.qualifier, a.explicitMetadata, a.deniedMetadataKeys) case att: AttributeReference => new ColumnarAttributeReference(att.name, att.dataType, att.nullable, att.metadata)(att.exprId, att.qualifier)
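For reference, a minimal sketch of how the new `deniedMetadataKeys` parameter of `Alias` is
expected to behave, based on the `Alias.metadata` change in this patch. The key name
`"__dataset_id"` is illustrative only; the patch itself passes `Dataset.DATASET_ID_KEY` and
`Dataset.COL_POS_KEY`, whose literal values are not shown in this excerpt.

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, MetadataBuilder}

// Child attribute carrying both a column-reference tracking entry and ordinary metadata.
val childMetadata = new MetadataBuilder()
  .putLong("__dataset_id", 1L)          // entry that should be dropped by the alias
  .putString("comment", "primary key")  // entry that should still be inherited
  .build()
val child = AttributeReference("key", IntegerType, nullable = false, childMetadata)()

// The alias strips only the denied keys when inheriting metadata from its child, so the
// aliased column is no longer treated as a dataset column reference by DetectAmbiguousSelfJoin.
val alias = Alias(child, "key2")(deniedMetadataKeys = Seq("__dataset_id"))
assert(alias.metadata.contains("comment"))
assert(!alias.metadata.contains("__dataset_id"))
```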