[SPARK-15425][SQL] Disallow cross joins by default
## What changes were proposed in this pull request?

In order to prevent users from inadvertently writing queries with Cartesian joins, this patch introduces a new conf `spark.sql.crossJoin.enabled` (set to `false` by default). When the flag is disabled, any query that contains one or more Cartesian products fails with an `AnalysisException`.
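
To illustrate the user-facing behavior, here is a minimal sketch (not part of this patch; the `spark` session and the `t1`/`t2` temp views are assumptions for the example):

```scala
// Hypothetical example: t1 and t2 are temp views that are too large to auto-broadcast.
val df = spark.sql("SELECT * FROM t1 JOIN t2")  // no join condition => Cartesian product

// With the default spark.sql.crossJoin.enabled=false, executing the plan fails with:
//   org.apache.spark.sql.AnalysisException: Cartesian joins could be prohibitively
//   expensive and are disabled by default. To explicitly enable them, please set
//   spark.sql.crossJoin.enabled = true
// df.collect()

// Opt back in either programmatically ...
spark.conf.set("spark.sql.crossJoin.enabled", "true")
// ... or through SQL:
spark.sql("SET spark.sql.crossJoin.enabled=true")

df.collect()  // now runs (and may still be very expensive)
```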

## How was this patch tested?

Added a test to verify the new behavior in `JoinSuite`. Additionally, `SQLQuerySuite` and `SQLMetricsSuite` were modified to explicitly enable Cartesian products.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #13209 from sameeragarwal/disallow-cartesian.
sameeragarwal authored and rxin committed May 23, 2016
1 parent fc44b69 commit dafcb05
Showing 10 changed files with 113 additions and 46 deletions.
@@ -190,7 +190,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
// This join could be very slow or OOM
joins.BroadcastNestedLoopJoinExec(
planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
planLater(left), planLater(right), buildSide, joinType, condition,
withinBroadcastThreshold = false) :: Nil

// --- Cases where this strategy does not apply ---------------------------------------------

@@ -19,20 +19,23 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.collection.{BitSet, CompactBuffer}

case class BroadcastNestedLoopJoinExec(
left: SparkPlan,
right: SparkPlan,
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression]) extends BinaryExecNode {
condition: Option[Expression],
withinBroadcastThreshold: Boolean = true) extends BinaryExecNode {

override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -337,6 +340,15 @@ case class BroadcastNestedLoopJoinExec(
)
}

protected override def doPrepare(): Unit = {
if (!withinBroadcastThreshold && !sqlContext.conf.crossJoinEnabled) {
throw new AnalysisException("Both sides of this join are outside the broadcasting " +
"threshold and computing it could be prohibitively expensive. To explicitly enable it, " +
s"please set ${SQLConf.CROSS_JOINS_ENABLED.key} = true")
}
super.doPrepare()
}

protected override def doExecute(): RDD[InternalRow] = {
val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()

@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins

import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter

@@ -88,6 +90,15 @@ case class CartesianProductExec(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

protected override def doPrepare(): Unit = {
if (!sqlContext.conf.crossJoinEnabled) {
throw new AnalysisException("Cartesian joins could be prohibitively expensive and are " +
"disabled by default. To explicitly enable them, please set " +
s"${SQLConf.CROSS_JOINS_ENABLED.key} = true")
}
super.doPrepare()
}

protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")

@@ -338,9 +338,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val CROSS_JOINS_ENABLED = SQLConfigBuilder("spark.sql.crossJoin.enabled")
.doc("When false, we will throw an error if a query contains a cross join")
.booleanConf
.createWithDefault(false)

val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal")
.doc("When true, the ordinal numbers are treated as the position in the select list. " +
"When false, the ordinal numbers in order/sort By clause are ignored.")
"When false, the ordinal numbers in order/sort by clause are ignored.")
.booleanConf
.createWithDefault(true)

@@ -622,6 +627,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {

def bucketingEnabled: Boolean = getConf(SQLConf.BUCKETING_ENABLED)

def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED)

// Do not use a value larger than 4000 as the default value of this property.
// See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
def schemaStringLengthThreshold: Int = getConf(SCHEMA_STRING_LENGTH_THRESHOLD)
31 changes: 23 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -62,7 +62,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
test("join operator selection") {
spark.cacheManager.clearCache()

withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0",
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
Seq(
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
classOf[SortMergeJoinExec]),
@@ -204,13 +205,27 @@
testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
}

test("cartisian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
test("cartesian product join") {
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "false") {
val e = intercept[Exception] {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
assert(e.getMessage.contains("Cartesian joins could be prohibitively expensive and are " +
"disabled by default"))
}
}

test("left outer join") {
32 changes: 19 additions & 13 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -104,9 +104,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
).toDF("a", "b", "c").createOrReplaceTempView("cachedData")

spark.catalog.cacheTable("cachedData")
checkAnswer(
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
Row(0) :: Row(81) :: Nil)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"),
Row(0) :: Row(81) :: Nil)
}
}

test("self join with aliases") {
@@ -435,10 +437,12 @@
}

test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq(Row(3, 1), Row(3, 2))
)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq(Row(3, 1), Row(3, 2))
)
}
}

test("left semi greater than predicate and equal operator") {
@@ -824,12 +828,14 @@
}

test("cartesian product join") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer(
testData3.join(testData3),
Row(1, null, 1, null) ::
Row(1, null, 2, 2) ::
Row(2, 2, 1, null) ::
Row(2, 2, 2, 2) :: Nil)
}
}

test("left outer join") {
@@ -187,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
}

test(s"$testName using CartesianProduct") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
CartesianProductExec(left, right, Some(condition())),
expectedAnswer.map(Row.fromTuple),
@@ -29,6 +29,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{AccumulatorContext, JsonProtocol, Utils}

@@ -237,16 +238,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
test("BroadcastNestedLoopJoin metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
testSparkPlanMetrics(df, 3, Map(
1L -> ("BroadcastNestedLoopJoin", Map(
"number of output rows" -> 12L)))
)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 left JOIN testDataForJoin ON " +
"testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a")
testSparkPlanMetrics(df, 3, Map(
1L -> ("BroadcastNestedLoopJoin", Map(
"number of output rows" -> 12L)))
)
}
}
}

@@ -263,17 +266,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}

test("CartesianProduct metrics") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 JOIN testDataForJoin")
testSparkPlanMetrics(df, 1, Map(
0L -> ("CartesianProduct", Map(
"number of output rows" -> 12L)))
)
withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.createOrReplaceTempView("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
// ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0)
val df = spark.sql(
"SELECT * FROM testData2 JOIN testDataForJoin")
testSparkPlanMetrics(df, 1, Map(
0L -> ("CartesianProduct", Map("number of output rows" -> 12L)))
)
}
}
}

@@ -40,6 +40,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc
private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled

def testCases: Seq[(String, File)] = {
hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f)
@@ -61,6 +62,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Ensures that the plans generation use metastore relation and not OrcRelation
// Was done because SqlBuilder does not work with plans having logical relation
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
// Ensures that cross joins are enabled so that we can test them
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
RuleExecutor.resetTime()
}

@@ -72,6 +75,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
TestHive.sessionState.functionRegistry.restore()

// For debugging dump some statistics about how much time was spent in various optimizer rules
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.sql.internal.SQLConf

case class TestData(a: Int, b: String)

@@ -48,13 +49,17 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {

import org.apache.spark.sql.hive.test.TestHive.implicits._

private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled

override def beforeAll() {
super.beforeAll()
TestHive.setCacheTables(true)
// Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
// Add Locale setting
Locale.setDefault(Locale.US)
// Ensures that cross joins are enabled so that we can test them
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
}

override def afterAll() {
@@ -63,6 +68,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TimeZone.setDefault(originalTimeZone)
Locale.setDefault(originalLocale)
sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
} finally {
super.afterAll()
}
