Skip to content

Commit

Permalink
[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only…
Browse files Browse the repository at this point in the history
… on the driver

## What changes were proposed in this pull request?

This is a followup of #20136 . #20136 didn't really work because in the test, we are using local backend, which shares the driver side `SparkEnv`, so `SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER` doesn't work.

This PR changes the check to `TaskContext.get != null`, and move the check to `SQLConf.get`, and fix all the places that violate this check:
* `InMemoryTableScanExec#createAndDecompressColumn` is executed inside `rdd.map`, we can't access `conf.offHeapColumnVectorEnabled` there. #21223 merged
* `DataType#sameType` may be executed in executor side, for things like json schema inference, so we can't call `conf.caseSensitiveAnalysis` there. This contributes to most of the code changes, as we need to add `caseSensitive` parameter to a lot of methods.
* `ParquetFilters` is used in the file scan function, which is executed in executor side, so we can't can't call `conf.parquetFilterPushDownDate` there. #21224 merged
* `WindowExec#createBoundOrdering` is called on executor side, so we can't use `conf.sessionLocalTimezone` there. #21225 merged
* `JsonToStructs` can be serialized to executors and evaluate, we should not call `SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)` in the body. #21226 merged

## How was this patch tested?

existing test

Author: Wenchen Fan <wenchen@databricks.com>

Closes #21190 from cloud-fan/minor.
  • Loading branch information
cloud-fan authored and HyukjinKwon committed May 11, 2018
1 parent d3c426a commit a4206d5
Show file tree
Hide file tree
Showing 10 changed files with 188 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -260,7 +261,9 @@ trait CheckAnalysis extends PredicateHelper {
// Check if the data types match.
dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) =>
// SPARK-18058: we shall not care about the nullability of columns
if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) {
val widerType = TypeCoercion.findWiderTypeForTwo(
dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis)
if (widerType.isEmpty) {
failAnalysis(
s"""
|${operator.nodeName} can only be performed on tables with the compatible
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas
// For each column, traverse all the values and find a common data type and nullability.
val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
val inputTypes = column.map(_.dataType)
val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse {
val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion(
inputTypes, conf.caseSensitiveAnalysis)
val tpe = wideType.getOrElse {
table.failAnalysis(s"incompatible types found in column $name for inline table")
}
StructField(name, tpe, nullable = column.exists(_.nullable))
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.util.matching.Regex

import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.ByteUnit
Expand Down Expand Up @@ -107,7 +107,13 @@ object SQLConf {
* run tests in parallel. At the time this feature was implemented, this was a no-op since we
* run unit tests (that does not involve SparkSession) in serial order.
*/
def get: SQLConf = confGetter.get()()
def get: SQLConf = {
if (Utils.isTesting && TaskContext.get != null) {
// we're accessing it during task execution, fail.
throw new IllegalStateException("SQLConf should only be created and accessed on the driver.")
}
confGetter.get()()
}

val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations")
.internal()
Expand Down Expand Up @@ -1274,12 +1280,6 @@ object SQLConf {
class SQLConf extends Serializable with Logging {
import SQLConf._

if (Utils.isTesting && SparkEnv.get != null) {
// assert that we're only accessing it on the driver.
assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER,
"SQLConf should only be created and accessed on the driver.")
}

/** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */
@transient protected[spark] val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,7 @@ abstract class DataType extends AbstractDataType {
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
*/
private[spark] def sameType(other: DataType): Boolean =
if (SQLConf.get.caseSensitiveAnalysis) {
DataType.equalsIgnoreNullability(this, other)
} else {
DataType.equalsIgnoreCaseAndNullability(this, other)
}
DataType.equalsIgnoreNullability(this, other)

/**
* Returns the same data type but set all nullability fields are true
Expand Down Expand Up @@ -218,7 +214,7 @@ object DataType {
/**
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
*/
private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest {
}

private def checkWidenType(
widenFunc: (DataType, DataType) => Option[DataType],
widenFunc: (DataType, DataType, Boolean) => Option[DataType],
t1: DataType,
t2: DataType,
expected: Option[DataType],
isSymmetric: Boolean = true): Unit = {
var found = widenFunc(t1, t2)
var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis)
assert(found == expected,
s"Expected $expected as wider common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
if (isSymmetric) {
found = widenFunc(t2, t1)
found = widenFunc(t2, t1, conf.caseSensitiveAnalysis)
assert(found == expected,
s"Expected $expected as wider common type for $t2 and $t1, found $found")
}
Expand Down Expand Up @@ -524,29 +524,29 @@ class TypeCoercionSuite extends AnalysisTest {
test("cast NullType for expressions that implement ExpectsInputTypes") {
import TypeCoercionSuite._

ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeUnaryExpression(Literal.create(null, NullType)),
AnyTypeUnaryExpression(Literal.create(null, NullType)))

ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeUnaryExpression(Literal.create(null, NullType)),
NumericTypeUnaryExpression(Literal.create(null, DoubleType)))
}

test("cast NullType for binary operators") {
import TypeCoercionSuite._

ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)))

ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)),
NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType)))
}

test("coalesce casts") {
val rule = TypeCoercion.FunctionArgumentConversion
val rule = TypeCoercion.FunctionArgumentConversion(conf)

val intLit = Literal(1)
val longLit = Literal.create(1L)
Expand Down Expand Up @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("CreateArray casts") {
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
Expand All @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))

ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal("a")
Expand All @@ -626,15 +626,15 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal("a"), StringType)
:: Nil))

ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal(1)
:: Nil),
CreateArray(Literal.create(null, DecimalType(5, 3)).cast(DecimalType(13, 3))
:: Literal(1).cast(DecimalType(13, 3))
:: Nil))

ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateArray(Literal.create(null, DecimalType(5, 3))
:: Literal.create(null, DecimalType(22, 10))
:: Literal.create(null, DecimalType(38, 38))
Expand All @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest {

test("CreateMap casts") {
// type coercion for map keys
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
Expand All @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal.create(2.0, FloatType), FloatType)
:: Literal("b")
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal.create(null, DecimalType(5, 3))
:: Literal("a")
:: Literal.create(2.0, FloatType)
Expand All @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal("b")
:: Nil))
// type coercion for map values
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2)
Expand All @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal(2)
:: Cast(Literal(3.0), StringType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal.create(null, DecimalType(38, 0))
:: Literal(2)
Expand All @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38))
:: Nil))
// type coercion for both map keys and values
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2.0)
Expand All @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest {

test("greatest/least cast") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
Expand All @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1000000000000000000000"))
Expand All @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal(1.0)
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
Expand All @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(10, 5)).cast(DoubleType)
:: Literal(1).cast(DoubleType)
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal.create(null, DecimalType(15, 0))
:: Literal.create(null, DecimalType(10, 5))
:: Literal(1)
Expand All @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest {
:: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5))
:: Literal(1).cast(DecimalType(20, 5))
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
operator(Literal.create(2L, LongType)
:: Literal(1)
:: Literal.create(null, DecimalType(10, 5))
Expand All @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("nanvl casts") {
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)),
NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType)))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType)))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType)))
ruleTest(TypeCoercion.FunctionArgumentConversion,
ruleTest(TypeCoercion.FunctionArgumentConversion(conf),
NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)),
NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType)))
}

test("type coercion for If") {
val rule = TypeCoercion.IfCoercion
val rule = TypeCoercion.IfCoercion(conf)
val intLit = Literal(1)
val doubleLit = Literal(1.0)
val trueLit = Literal.create(true, BooleanType)
Expand Down Expand Up @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("type coercion for CaseKeyWhen") {
ruleTest(new TypeCoercion.ImplicitTypeCasts(conf),
ruleTest(TypeCoercion.ImplicitTypeCasts(conf),
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
ruleTest(TypeCoercion.CaseWhenCoercion(conf),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
ruleTest(TypeCoercion.CaseWhenCoercion(conf),
CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Literal(1.2))),
Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
)
ruleTest(TypeCoercion.CaseWhenCoercion,
ruleTest(TypeCoercion.CaseWhenCoercion(conf),
CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
Expand Down Expand Up @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest {
private val timeZoneResolver = ResolveTimeZone(new SQLConf)

private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = {
timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan))
timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan))
}

test("WidenSetOperationTypes for except and intersect") {
Expand Down Expand Up @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest {

test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
"in aggregation function like sum") {
val rules = Seq(FunctionArgumentConversion, Division)
val rules = Seq(FunctionArgumentConversion(conf), Division)
// Casts Integer to Double
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
Expand All @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest {
}

test("SPARK-17117 null type coercion in divide") {
val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf))
val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf))
val nullLit = Literal.create(null, NullType)
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
Expand Down
Loading

0 comments on commit a4206d5

Please sign in to comment.