[SPARK-33071][SPARK-33536][SQL] Avoid changing dataset_id of LogicalPlan in join() to not break DetectAmbiguousSelfJoin

### What changes were proposed in this pull request?

Currently, `join()` uses `withPlan(logicalPlan)` for convenience, in order to call some Dataset functions. But this makes the `dataset_id` inconsistent between the `logicalPlan` and the original `Dataset` (because `withPlan(logicalPlan)` creates a new Dataset with a new id and resets the `dataset_id` of the `logicalPlan` to that new id). As a result, it breaks the rule `DetectAmbiguousSelfJoin`.

In this PR, we propose to drop the usage of `withPlan` and use the `logicalPlan` directly, so that its `dataset_id` doesn't change.

Besides, this PR also removes the related metadata (`DATASET_ID_KEY`, `COL_POS_KEY`) when an `Alias` constructs its own metadata, because an `Alias` is no longer a column reference once it has been converted to an `Attribute`. To achieve this, we add a new field, `deniedMetadataKeys`, to indicate the metadata keys that need to be removed.
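
To illustrate, here is a minimal standalone sketch of the new behaviour (not the patch itself — the helper object and the key-name values are assumed for illustration; the real change lives in `Alias.metadata`, shown in the `namedExpressions.scala` diff below):

```scala
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

object DeniedMetadataKeysSketch {
  // Assumed key names, for illustration only.
  val DATASET_ID_KEY = "__dataset_id"
  val COL_POS_KEY = "__col_position"

  // Mirrors what the new Alias.metadata does: inherit the child's metadata,
  // minus the keys listed in deniedMetadataKeys.
  def inherit(child: Metadata, deniedMetadataKeys: Seq[String]): Metadata = {
    val builder = new MetadataBuilder().withMetadata(child)
    deniedMetadataKeys.foreach(builder.remove)
    builder.build()
  }

  def main(args: Array[String]): Unit = {
    val childMetadata = new MetadataBuilder()
      .putLong(DATASET_ID_KEY, 42L)
      .putLong(COL_POS_KEY, 0L)
      .putString("comment", "kept")
      .build()
    // Only the column-reference bookkeeping is dropped; other entries survive.
    println(inherit(childMetadata, Seq(DATASET_ID_KEY, COL_POS_KEY)))
    // {"comment":"kept"}
  }
}
```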

### Why are the changes needed?

For the query below, Spark returns a wrong result, while it should throw an ambiguous self-join exception instead:

```scala
val emp1 = Seq[TestData](
  TestData(1, "sales"),
  TestData(2, "personnel"),
  TestData(3, "develop"),
  TestData(4, "IT")).toDS()
val emp2 = Seq[TestData](
  TestData(1, "sales"),
  TestData(2, "personnel"),
  TestData(3, "develop")).toDS()
val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))
emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer")
  .select(emp1.col("*"), emp3.col("key").as("e2")).show()

// wrong result
+---+---------+---+
|key|    value| e2|
+---+---------+---+
|  1|    sales|  1|
|  2|personnel|  2|
|  3|  develop|  3|
|  4|       IT|  4|
+---+---------+---+
```
This PR fixes the wrong behaviour.
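
With the fix, the same query fails at analysis time. Roughly (the message below is abbreviated and may vary across versions):

```scala
emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer")
  .select(emp1.col("*"), emp3.col("key").as("e2")).show()

// org.apache.spark.sql.AnalysisException: Column key#... are ambiguous.
// It's probably because you joined several Datasets together, and some of these
// Datasets are the same. ... You can also set
// spark.sql.analyzer.failAmbiguousSelfJoin to false to disable this check.
```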

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, users hit an ambiguous self-join exception instead of getting a wrong result.

### How was this patch tested?

Added a new unit test.
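
A regression test for this looks roughly like the following sketch (not the exact test added by the patch; it assumes Spark's SQL test harness, where `TestData`, `testImplicits`, and ScalaTest's `intercept` are available):

```scala
test("SPARK-33536: join with a Dataset derived from the same Dataset") {
  import testImplicits._

  val emp1 = Seq(
    TestData(1, "sales"),
    TestData(2, "personnel"),
    TestData(3, "develop"),
    TestData(4, "IT")).toDS()
  val emp2 = Seq(
    TestData(1, "sales"),
    TestData(2, "personnel"),
    TestData(3, "develop")).toDS()
  // emp3 derives from emp1, so emp3("key") still resolves against emp1's plan.
  val emp3 = emp1.join(emp2, emp1("key") === emp2("key")).select(emp1("*"))

  val e = intercept[AnalysisException] {
    emp1.join(emp3, emp1.col("key") === emp3.col("key"), "left_outer")
      .select(emp1.col("*"), emp3.col("key").as("e2"))
      .collect()
  }
  assert(e.message.contains("ambiguous"))
}
```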

Closes #30488 from Ngone51/fix-self-join.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
a0x8o committed Dec 2, 2020
1 parent 50fc2f4 commit 8cef24b
Showing 23 changed files with 207 additions and 127 deletions.
@@ -38,7 +38,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = PYSPARK_DECOMISSIONING,
mainClass = "",
- expectedLogOnCompletion = Seq(
+ expectedDriverLogOnCompletion = Seq(
"Finished waiting, stopping Spark",
"Decommission executors",
"Final accumulator value is: 100"),
@@ -69,7 +69,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = PYSPARK_DECOMISSIONING_CLEANUP,
mainClass = "",
- expectedLogOnCompletion = Seq(
+ expectedDriverLogOnCompletion = Seq(
"Finished waiting, stopping Spark",
"Decommission executors"),
appArgs = Array.empty[String],
@@ -104,7 +104,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = PYSPARK_SCALE,
mainClass = "",
- expectedLogOnCompletion = Seq(
+ expectedDriverLogOnCompletion = Seq(
"Finished waiting, stopping Spark",
"Decommission executors"),
appArgs = Array.empty[String],
@@ -177,7 +177,7 @@ private[spark] trait DepsTestsSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = pySparkFiles,
mainClass = "",
- expectedLogOnCompletion = Seq(
+ expectedDriverLogOnCompletion = Seq(
"Python runtime version check is: True",
"Python environment version check is: True",
"Python runtime version check for executor is: True"),
@@ -171,6 +171,7 @@ class KubernetesSuite extends SparkFunSuite
appResource,
SPARK_PI_MAIN_CLASS,
Seq("Pi is roughly 3"),
+ Seq(),
appArgs,
driverPodChecker,
executorPodChecker,
@@ -192,6 +193,7 @@ class KubernetesSuite extends SparkFunSuite
SPARK_DFS_READ_WRITE_TEST,
Seq(s"Success! Local Word Count $wordCount and " +
s"DFS Word Count $wordCount agree."),
+ Seq(),
appArgs,
driverPodChecker,
executorPodChecker,
@@ -212,6 +214,7 @@ class KubernetesSuite extends SparkFunSuite
appResource,
SPARK_REMOTE_MAIN_CLASS,
Seq(s"Mounting of ${appArgs.head} was true"),
+ Seq(),
appArgs,
driverPodChecker,
executorPodChecker,
@@ -261,7 +264,8 @@ class KubernetesSuite extends SparkFunSuite
protected def runSparkApplicationAndVerifyCompletion(
appResource: String,
mainClass: String,
- expectedLogOnCompletion: Seq[String],
+ expectedDriverLogOnCompletion: Seq[String],
+ expectedExecutorLogOnCompletion: Seq[String] = Seq(),
appArgs: Array[String],
driverPodChecker: Pod => Unit,
executorPodChecker: Pod => Unit,
@@ -374,7 +378,6 @@ class KubernetesSuite extends SparkFunSuite
.list()
.getItems
.get(0)

driverPodChecker(driverPod)

// If we're testing decommissioning we an executors, but we should have an executor
@@ -383,14 +386,35 @@ class KubernetesSuite extends SparkFunSuite
execPods.values.nonEmpty should be (true)
}
execPods.values.foreach(executorPodChecker(_))

+ val execPod: Option[Pod] = if (expectedExecutorLogOnCompletion.nonEmpty) {
+ Some(kubernetesTestComponents.kubernetesClient
+ .pods()
+ .withLabel("spark-app-locator", appLocator)
+ .withLabel("spark-role", "executor")
+ .list()
+ .getItems
+ .get(0))
+ } else {
+ None
+ }

Eventually.eventually(patienceTimeout, patienceInterval) {
- expectedLogOnCompletion.foreach { e =>
+ expectedDriverLogOnCompletion.foreach { e =>
assert(kubernetesTestComponents.kubernetesClient
.pods()
.withName(driverPod.getMetadata.getName)
.getLog
.contains(e),
s"The application did not complete, did not find str ${e}")
s"The application did not complete, driver log did not contain str ${e}")
}
+ expectedExecutorLogOnCompletion.foreach { e =>
+ assert(kubernetesTestComponents.kubernetesClient
+ .pods()
+ .withName(execPod.get.getMetadata.getName)
+ .getLog
+ .contains(e),
+ s"The application did not complete, executor log did not contain str ${e}")
+ }
}
execWatcher.close()
@@ -27,7 +27,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = PYSPARK_PI,
mainClass = "",
- expectedLogOnCompletion = Seq("Pi is roughly 3"),
+ expectedDriverLogOnCompletion = Seq("Pi is roughly 3"),
appArgs = Array("5"),
driverPodChecker = doBasicDriverPyPodCheck,
executorPodChecker = doBasicExecutorPyPodCheck,
@@ -41,7 +41,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = PYSPARK_FILES,
mainClass = "",
- expectedLogOnCompletion = Seq(
+ expectedDriverLogOnCompletion = Seq(
"Python runtime version check is: True",
"Python environment version check is: True",
"Python runtime version check for executor is: True"),
@@ -61,7 +61,7 @@ private[spark] trait PythonTestsSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = PYSPARK_MEMORY_CHECK,
mainClass = "",
- expectedLogOnCompletion = Seq(
+ expectedDriverLogOnCompletion = Seq(
"PySpark Worker Memory Check is: True"),
appArgs = Array(s"$additionalMemoryInBytes"),
driverPodChecker = doDriverMemoryCheck,
@@ -26,7 +26,7 @@ private[spark] trait RTestsSuite { k8sSuite: KubernetesSuite =>
runSparkApplicationAndVerifyCompletion(
appResource = SPARK_R_DATAFRAME_TEST,
mainClass = "",
- expectedLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"),
+ expectedDriverLogOnCompletion = Seq("name: string (nullable = true)", "1 Justin"),
appArgs = Array.empty[String],
driverPodChecker = doBasicDriverRPodCheck,
executorPodChecker = doBasicExecutorRPodCheck,
@@ -16,14 +16,11 @@
*/
package org.apache.spark.deploy.k8s.integrationtest

- import java.io.{BufferedWriter, File, FileWriter}
+ import java.io.File
import java.net.URL
+ import java.nio.file.Files

- import scala.io.{BufferedSource, Source}
-
- import io.fabric8.kubernetes.api.model._
-
- import org.apache.spark.internal.config
+ import scala.io.Source

private[spark] trait SparkConfPropagateSuite { k8sSuite: KubernetesSuite =>
import KubernetesSuite.{k8sTestTag, SPARK_PI_MAIN_CLASS}
@@ -38,18 +35,21 @@ private[spark] trait SparkConfPropagateSuite { k8sSuite: KubernetesSuite =>
val logConfFilePath = s"${sparkHomeDir.toFile}/conf/log4j.properties"

try {
- val writer = new BufferedWriter(new FileWriter(logConfFilePath))
- writer.write(content)
- writer.close()
+ Files.write(new File(logConfFilePath).toPath, content.getBytes)

sparkAppConf.set("spark.driver.extraJavaOptions", "-Dlog4j.debug")
sparkAppConf.set("spark.executor.extraJavaOptions", "-Dlog4j.debug")

+ val log4jExpectedLog =
+ s"log4j: Reading configuration from URL file:/opt/spark/conf/log4j.properties"

runSparkApplicationAndVerifyCompletion(
appResource = containerLocalSparkDistroExamplesJar,
mainClass = SPARK_PI_MAIN_CLASS,
- expectedLogOnCompletion = (Seq("DEBUG",
- s"log4j: Reading configuration from URL file:/opt/spark/conf/log4j.properties",
+ expectedDriverLogOnCompletion = (Seq("DEBUG",
+ log4jExpectedLog,
"Pi is roughly 3")),
+ expectedExecutorLogOnCompletion = Seq(log4jExpectedLog),
appArgs = Array.empty[String],
driverPodChecker = doBasicDriverPodCheck,
executorPodChecker = doBasicExecutorPodCheck,
@@ -89,7 +89,8 @@ trait AliasHelper {
a.copy(child = trimAliases(a.child))(
exprId = a.exprId,
qualifier = a.qualifier,
- explicitMetadata = Some(a.metadata))
+ explicitMetadata = Some(a.metadata),
+ deniedMetadataKeys = a.deniedMetadataKeys)
case a: MultiAlias =>
a.copy(child = trimAliases(a.child))
case other => trimAliases(other)
@@ -31,10 +31,16 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
- * Trait to indicate the expression doesn't have any side effects. This can be used
- * to indicate its ok to optimize it out under certain circumstances.
+ * Trait to indicate the expression does not throw an exception by itself when they are evaluated.
+ * For example, UDFs, [[AssertTrue]], etc can throw an exception when they are executed.
+ * In such case, it is necessary to call [[Expression.eval]], and the optimization rule should
+ * not ignore it.
+ *
+ * This trait can be used in an optimization rule such as
+ * [[org.apache.spark.sql.catalyst.optimizer.ConstantFolding]] to fold the expressions that
+ * do not need to execute, for example, `size(array(c0, c1, c2))`.
 */
- trait NoSideEffect
+ trait NoThrow

/**
* Returns an Array containing the evaluation of all children expressions.
@@ -48,7 +54,7 @@ trait NoSideEffect
""",
since = "1.1.0")
case class CreateArray(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
- extends Expression with NoSideEffect {
+ extends Expression with NoThrow {

def this(children: Seq[Expression]) = {
this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE))
@@ -166,7 +172,7 @@ private [sql] object GenArrayData {
""",
since = "2.0.0")
case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
- extends Expression with NoSideEffect{
+ extends Expression with NoThrow {

def this(children: Seq[Expression]) = {
this(children, SQLConf.get.getConf(SQLConf.LEGACY_CREATE_EMPTY_COLLECTION_USING_STRING_TYPE))
@@ -385,7 +391,7 @@ object CreateStruct {
""",
since = "1.5.0")
// scalastyle:on line.size.limit
- case class CreateNamedStruct(children: Seq[Expression]) extends Expression with NoSideEffect {
+ case class CreateNamedStruct(children: Seq[Expression]) extends Expression with NoThrow {
lazy val (nameExprs, valExprs) = children.grouped(2).map {
case Seq(name, value) => (name, value)
}.toList.unzip
@@ -394,7 +394,7 @@ trait GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
val keyJavaType = CodeGenerator.javaType(keyType)
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
val keyNotFoundBranch = if (failOnError) {
s"""throw new NoSuchElementException("Key " + $eval2 + " does not exist.");"""
s"""throw new java.util.NoSuchElementException("Key " + $eval2 + " does not exist.");"""
} else {
s"${ev.isNull} = true;"
}
@@ -1789,8 +1789,11 @@ private case class GetTimestamp(
""",
group = "datetime_funcs",
since = "3.0.0")
- case class MakeDate(year: Expression, month: Expression, day: Expression,
-     failOnError: Boolean = SQLConf.get.ansiEnabled)
+ case class MakeDate(
+     year: Expression,
+     month: Expression,
+     day: Expression,
+     failOnError: Boolean = SQLConf.get.ansiEnabled)
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(year: Expression, month: Expression, day: Expression) =
@@ -166,13 +166,13 @@ case class MakeInterval(
extends SeptenaryExpression with ImplicitCastInputTypes with NullIntolerant {

def this(
-      years: Expression,
-      months: Expression,
-      weeks: Expression,
-      days: Expression,
-      hours: Expression,
-      mins: Expression,
-      sec: Expression) = {
+    years: Expression,
+    months: Expression,
+    weeks: Expression,
+    days: Expression,
+    hours: Expression,
+    mins: Expression,
+    sec: Expression) = {
this(years, months, weeks, days, hours, mins, sec, SQLConf.get.ansiEnabled)
}
def this(
@@ -143,11 +143,14 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant
* fully qualified way. Consider the examples tableName.name, subQueryAlias.name.
* tableName and subQueryAlias are possible qualifiers.
* @param explicitMetadata Explicit metadata associated with this alias that overwrites child's.
+ * @param deniedMetadataKeys Keys of metadata entries that are supposed to be removed when
+ * inheriting the metadata from the child.
*/
case class Alias(child: Expression, name: String)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifier: Seq[String] = Seq.empty,
- val explicitMetadata: Option[Metadata] = None)
+ val explicitMetadata: Option[Metadata] = None,
+ val deniedMetadataKeys: Seq[String] = Seq.empty)
extends UnaryExpression with NamedExpression {

// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
@@ -167,7 +170,11 @@ case class Alias(child: Expression, name: String)(
override def metadata: Metadata = {
explicitMetadata.getOrElse {
child match {
- case named: NamedExpression => named.metadata
+ case named: NamedExpression =>
+   val builder = new MetadataBuilder().withMetadata(named.metadata)
+   deniedMetadataKeys.foreach(builder.remove)
+   builder.build()

case _ => Metadata.empty
}
}
@@ -194,7 +201,7 @@ case class Alias(child: Expression, name: String)(
override def toString: String = s"$child AS $name#${exprId.id}$typeSuffix$delaySuffix"

override protected final def otherCopyArgs: Seq[AnyRef] = {
- exprId :: qualifier :: explicitMetadata :: Nil
+ exprId :: qualifier :: explicitMetadata :: deniedMetadataKeys :: Nil
}

override def hashCode(): Int = {
@@ -205,7 +212,7 @@ case class Alias(child: Expression, name: String)(
override def equals(other: Any): Boolean = other match {
case a: Alias =>
name == a.name && exprId == a.exprId && child == a.child && qualifier == a.qualifier &&
- explicitMetadata == a.explicitMetadata
+ explicitMetadata == a.explicitMetadata && deniedMetadataKeys == a.deniedMetadataKeys
case _ => false
}

@@ -45,7 +45,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
private def hasNoSideEffect(e: Expression): Boolean = e match {
case _: Attribute => true
case _: Literal => true
- case _: NoSideEffect => e.children.forall(hasNoSideEffect)
+ case _: NoThrow if e.deterministic => e.children.forall(hasNoSideEffect)
case _ => false
}
