From e39b7d02746812e6c2eb3bc44ba3de37c12768d6 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 16:31:25 +0800 Subject: [PATCH 1/8] Revert "[SPARK-22938][SQL][FOLLOWUP] Assert that SQLConf.get is accessed only on the driver" This reverts commit a4206d58e05ab9ed6f01fee57e18dee65cbc4efc. --- .../sql/catalyst/analysis/CheckAnalysis.scala | 5 +- .../analysis/ResolveInlineTables.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 156 ++++++++---------- .../apache/spark/sql/internal/SQLConf.scala | 16 +- .../org/apache/spark/sql/types/DataType.scala | 8 +- .../catalyst/analysis/TypeCoercionSuite.scala | 70 ++++---- .../org/apache/spark/sql/SparkSession.scala | 21 +-- .../datasources/PartitioningUtils.scala | 5 +- .../datasources/json/JsonInferSchema.scala | 39 ++--- .../datasources/json/JsonSuite.scala | 4 +- 10 files changed, 140 insertions(+), 188 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 94b0561529e71..90bda2a72ad82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -261,9 +260,7 @@ trait CheckAnalysis extends PredicateHelper { // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns - val widerType = TypeCoercion.findWiderTypeForTwo( - dt1.asNullable, dt2.asNullable, SQLConf.get.caseSensitiveAnalysis) - if (widerType.isEmpty) { + if (TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty) { failAnalysis( s""" |${operator.nodeName} can only be performed on tables with the compatible diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 4eb6e642b1c37..f2df3e132629f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -83,9 +83,7 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with Cas // For each column, traverse all the values and find a common data type and nullability. val fields = table.rows.transpose.zip(table.names).map { case (column, name) => val inputTypes = column.map(_.dataType) - val wideType = TypeCoercion.findWiderTypeWithoutStringPromotion( - inputTypes, conf.caseSensitiveAnalysis) - val tpe = wideType.getOrElse { + val tpe = TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { table.failAnalysis(s"incompatible types found in column $name for inline table") } StructField(name, tpe, nullable = column.exists(_.nullable)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a7ba201509b78..b2817b0538a7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -48,18 +48,18 @@ object TypeCoercion { def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] = InConversion(conf) :: - WidenSetOperationTypes(conf) :: + WidenSetOperationTypes :: PromoteStrings(conf) :: DecimalPrecision :: BooleanEquality :: - FunctionArgumentConversion(conf) :: + FunctionArgumentConversion :: ConcatCoercion(conf) :: EltCoercion(conf) :: - CaseWhenCoercion(conf) :: - IfCoercion(conf) :: + CaseWhenCoercion :: + IfCoercion :: StackCoercion :: Division :: - ImplicitTypeCasts(conf) :: + new ImplicitTypeCasts(conf) :: DateTimeOperations :: WindowFrameCoercion :: Nil @@ -83,10 +83,7 @@ object TypeCoercion { * with primitive types, because in that case the precision and scale of the result depends on * the operation. Those rules are implemented in [[DecimalPrecision]]. */ - def findTightestCommonType( - left: DataType, - right: DataType, - caseSensitive: Boolean): Option[DataType] = (left, right) match { + val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) @@ -105,32 +102,22 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) - case (t1 @ StructType(fields1), t2 @ StructType(fields2)) => - val isSameType = if (caseSensitive) { - DataType.equalsIgnoreNullability(t1, t2) - } else { - DataType.equalsIgnoreCaseAndNullability(t1, t2) - } - - if (isSameType) { - Some(StructType(fields1.zip(fields2).map { case (f1, f2) => - // Since t1 is same type of t2, two StructTypes have the same DataType - // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. - // - Different names: use f1.name - // - Different nullabilities: `nullable` is true iff one of them is nullable. - val dataType = findTightestCommonType(f1.dataType, f2.dataType, caseSensitive).get - StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) - })) - } else { - None - } + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. + val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) case (a1 @ ArrayType(et1, hasNull1), a2 @ ArrayType(et2, hasNull2)) if a1.sameType(a2) => - findTightestCommonType(et1, et2, caseSensitive).map(ArrayType(_, hasNull1 || hasNull2)) + findTightestCommonType(et1, et2).map(ArrayType(_, hasNull1 || hasNull2)) case (m1 @ MapType(kt1, vt1, hasNull1), m2 @ MapType(kt2, vt2, hasNull2)) if m1.sameType(m2) => - val keyType = findTightestCommonType(kt1, kt2, caseSensitive) - val valueType = findTightestCommonType(vt1, vt2, caseSensitive) + val keyType = findTightestCommonType(kt1, kt2) + val valueType = findTightestCommonType(vt1, vt2) Some(MapType(keyType.get, valueType.get, hasNull1 || hasNull2)) case _ => None @@ -185,14 +172,13 @@ object TypeCoercion { * i.e. the main difference with [[findTightestCommonType]] is that here we allow some * loss of precision when widening decimal and double, and promotion to string. */ - def findWiderTypeForTwo(t1: DataType, t2: DataType, caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse(stringPromotion(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeForTwo(et1, et2, caseSensitive) - .map(ArrayType(_, containsNull1 || containsNull2)) + findWiderTypeForTwo(et1, et2).map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } @@ -207,8 +193,7 @@ object TypeCoercion { case _ => false } - private def findWiderCommonType( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + private def findWiderCommonType(types: Seq[DataType]): Option[DataType] = { // findWiderTypeForTwo doesn't satisfy the associative law, i.e. (a op b) op c may not equal // to a op (b op c). This is only a problem for StringType or nested StringType in ArrayType. // Excluding these types, findWiderTypeForTwo satisfies the associative law. For instance, @@ -216,7 +201,7 @@ object TypeCoercion { val (stringTypes, nonStringTypes) = types.partition(hasStringType(_)) (stringTypes.distinct ++ nonStringTypes).foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeForTwo(d, c) case _ => None }) } @@ -228,22 +213,20 @@ object TypeCoercion { */ private[analysis] def findWiderTypeWithoutStringPromotionForTwo( t1: DataType, - t2: DataType, - caseSensitive: Boolean): Option[DataType] = { - findTightestCommonType(t1, t2, caseSensitive) + t2: DataType): Option[DataType] = { + findTightestCommonType(t1, t2) .orElse(findWiderTypeForDecimal(t1, t2)) .orElse((t1, t2) match { case (ArrayType(et1, containsNull1), ArrayType(et2, containsNull2)) => - findWiderTypeWithoutStringPromotionForTwo(et1, et2, caseSensitive) + findWiderTypeWithoutStringPromotionForTwo(et1, et2) .map(ArrayType(_, containsNull1 || containsNull2)) case _ => None }) } - def findWiderTypeWithoutStringPromotion( - types: Seq[DataType], caseSensitive: Boolean): Option[DataType] = { + def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c, caseSensitive) + case Some(d) => findWiderTypeWithoutStringPromotionForTwo(d, c) case None => None }) } @@ -296,32 +279,29 @@ object TypeCoercion { * * This rule is only applied to Union/Except/Intersect */ - case class WidenSetOperationTypes(conf: SQLConf) extends Rule[LogicalPlan] { + object WidenSetOperationTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case s @ SetOperation(left, right) if s.childrenResolved && left.output.length == right.output.length && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(left :: right :: Nil, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil) assert(newChildren.length == 2) s.makeCopy(Array(newChildren.head, newChildren.last)) case s: Union if s.childrenResolved && s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved => - val newChildren: Seq[LogicalPlan] = - buildNewChildrenWithWiderTypes(s.children, conf.caseSensitiveAnalysis) + val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children) s.makeCopy(Array(newChildren)) } /** Build new children with the widest types for each attribute among all the children */ - private def buildNewChildrenWithWiderTypes( - children: Seq[LogicalPlan], caseSensitive: Boolean): Seq[LogicalPlan] = { + private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = { require(children.forall(_.output.length == children.head.output.length)) // Get a sequence of data types, each of which is the widest type of this specific attribute // in all the children val targetTypes: Seq[DataType] = - getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType](), caseSensitive) + getWidestTypes(children, attrIndex = 0, mutable.Queue[DataType]()) if (targetTypes.nonEmpty) { // Add an extra Project if the targetTypes are different from the original types. @@ -336,19 +316,18 @@ object TypeCoercion { @tailrec private def getWidestTypes( children: Seq[LogicalPlan], attrIndex: Int, - castedTypes: mutable.Queue[DataType], - caseSensitive: Boolean): Seq[DataType] = { + castedTypes: mutable.Queue[DataType]): Seq[DataType] = { // Return the result after the widen data types have been found for all the children if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType), caseSensitive) match { + findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column case Some(widenType) => castedTypes.enqueue(widenType) - getWidestTypes(children, attrIndex + 1, castedTypes, caseSensitive) + getWidestTypes(children, attrIndex + 1, castedTypes) } } @@ -453,7 +432,7 @@ object TypeCoercion { val commonTypes = lhs.zip(rhs).flatMap { case (l, r) => findCommonTypeForBinaryComparison(l.dataType, r.dataType, conf) - .orElse(findTightestCommonType(l.dataType, r.dataType, conf.caseSensitiveAnalysis)) + .orElse(findTightestCommonType(l.dataType, r.dataType)) } // The number of columns/expressions must match between LHS and RHS of an @@ -482,7 +461,7 @@ object TypeCoercion { } case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType), conf.caseSensitiveAnalysis) match { + findWiderCommonType(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -536,7 +515,7 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - case class FunctionArgumentConversion(conf: SQLConf) extends TypeCoercionRule { + object FunctionArgumentConversion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. @@ -544,7 +523,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -552,7 +531,7 @@ object TypeCoercion { case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) && !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType))) case None => c } @@ -563,7 +542,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -573,7 +552,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -601,7 +580,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types, conf.caseSensitiveAnalysis) match { + findWiderCommonType(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -611,14 +590,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types, conf.caseSensitiveAnalysis) match { + findWiderTypeWithoutStringPromotion(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -658,11 +637,11 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - case class CaseWhenCoercion(conf: SQLConf) extends TypeCoercionRule { + object CaseWhenCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes, conf.caseSensitiveAnalysis) + val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -689,17 +668,16 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - case class IfCoercion(conf: SQLConf) extends TypeCoercionRule { + object IfCoercion extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - widestType => - val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) - val newRight = if (right.dataType == widestType) right else Cast(right, widestType) - If(pred, newLeft, newRight) + findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) @@ -798,11 +776,12 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - case class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { + class ImplicitTypeCasts(conf: SQLConf) extends TypeCoercionRule { private def rejectTzInString = conf.getConf(SQLConf.REJECT_TIMEZONE_IN_STRING) - override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + override protected def coerceTypes( + plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -825,18 +804,17 @@ object TypeCoercion { } case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonType(left.dataType, right.dataType, conf.caseSensitiveAnalysis).map { - commonType => - if (b.inputType.acceptsType(commonType)) { - // If the expression accepts the tightest common type, cast to that. - val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) - val newRight = if (right.dataType == commonType) right else Cast(right, commonType) - b.withNewChildren(Seq(newLeft, newRight)) - } else { - // Otherwise, don't do anything with the expression. - b - } - }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. + findTightestCommonType(left.dataType, right.dataType).map { commonType => + if (b.inputType.acceptsType(commonType)) { + // If the expression accepts the tightest common type, cast to that. + val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) + val newRight = if (right.dataType == commonType) right else Cast(right, commonType) + b.withNewChildren(Seq(newLeft, newRight)) + } else { + // Otherwise, don't do anything with the expression. + b + } + }.getOrElse(b) // If there is no applicable conversion, leave expression unchanged. case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty => val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0b1965c438e27..b00edca97cd44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,7 +27,7 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.TaskContext +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit @@ -107,13 +107,7 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException("SQLConf should only be created and accessed on the driver.") - } - confGetter.get()() - } + def get: SQLConf = confGetter.get()() val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1280,6 +1274,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4ee12db9c10ca..0bef11659fc9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -81,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true @@ -214,7 +218,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index f73e045685ee1..0acd3b490447d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -128,17 +128,17 @@ class TypeCoercionSuite extends AnalysisTest { } private def checkWidenType( - widenFunc: (DataType, DataType, Boolean) => Option[DataType], + widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, expected: Option[DataType], isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2, conf.caseSensitiveAnalysis) + var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. if (isSymmetric) { - found = widenFunc(t2, t1, conf.caseSensitiveAnalysis) + found = widenFunc(t2, t1) assert(found == expected, s"Expected $expected as wider common type for $t2 and $t1, found $found") } @@ -524,11 +524,11 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeUnaryExpression(Literal.create(null, NullType)), NumericTypeUnaryExpression(Literal.create(null, DoubleType))) } @@ -536,17 +536,17 @@ class TypeCoercionSuite extends AnalysisTest { test("cast NullType for binary operators") { import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), NumericTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), NumericTypeBinaryOperator(Literal.create(null, DoubleType), Literal.create(null, DoubleType))) } test("coalesce casts") { - val rule = TypeCoercion.FunctionArgumentConversion(conf) + val rule = TypeCoercion.FunctionArgumentConversion val intLit = Literal(1) val longLit = Literal.create(1L) @@ -606,7 +606,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("CreateArray casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -616,7 +616,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal(1.0) :: Literal(1) :: Literal("a") @@ -626,7 +626,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal("a"), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal(1) :: Nil), @@ -634,7 +634,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(1).cast(DecimalType(13, 3)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateArray(Literal.create(null, DecimalType(5, 3)) :: Literal.create(null, DecimalType(22, 10)) :: Literal.create(null, DecimalType(38, 38)) @@ -647,7 +647,7 @@ class TypeCoercionSuite extends AnalysisTest { test("CreateMap casts") { // type coercion for map keys - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -658,7 +658,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal.create(2.0, FloatType), FloatType) :: Literal("b") :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal.create(null, DecimalType(5, 3)) :: Literal("a") :: Literal.create(2.0, FloatType) @@ -670,7 +670,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal("b") :: Nil)) // type coercion for map values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2) @@ -681,7 +681,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal(2) :: Cast(Literal(3.0), StringType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal.create(null, DecimalType(38, 0)) :: Literal(2) @@ -693,7 +693,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(38, 38)).cast(DecimalType(38, 38)) :: Nil)) // type coercion for both map keys and values - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, CreateMap(Literal(1) :: Literal("a") :: Literal(2.0) @@ -708,7 +708,7 @@ class TypeCoercionSuite extends AnalysisTest { test("greatest/least cast") { for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -717,7 +717,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -726,7 +726,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Cast(Literal(1), DecimalType(22, 0)) :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal(1.0) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -735,7 +735,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DoubleType) :: Literal(1).cast(DoubleType) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(null, DecimalType(15, 0)) :: Literal.create(null, DecimalType(10, 5)) :: Literal(1) @@ -744,7 +744,7 @@ class TypeCoercionSuite extends AnalysisTest { :: Literal.create(null, DecimalType(10, 5)).cast(DecimalType(20, 5)) :: Literal(1).cast(DecimalType(20, 5)) :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, operator(Literal.create(2L, LongType) :: Literal(1) :: Literal.create(null, DecimalType(10, 5)) @@ -757,25 +757,25 @@ class TypeCoercionSuite extends AnalysisTest { } test("nanvl casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(1.0, DoubleType)), NaNvl(Cast(Literal.create(1.0f, FloatType), DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0f, FloatType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(1.0f, FloatType), DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType)), NaNvl(Literal.create(1.0, DoubleType), Literal.create(1.0, DoubleType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0f, FloatType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0f, FloatType), Cast(Literal.create(null, NullType), FloatType))) - ruleTest(TypeCoercion.FunctionArgumentConversion(conf), + ruleTest(TypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, DoubleType), Literal.create(null, NullType)), NaNvl(Literal.create(1.0, DoubleType), Cast(Literal.create(null, NullType), DoubleType))) } test("type coercion for If") { - val rule = TypeCoercion.IfCoercion(conf) + val rule = TypeCoercion.IfCoercion val intLit = Literal(1) val doubleLit = Literal(1.0) val trueLit = Literal.create(true, BooleanType) @@ -823,20 +823,20 @@ class TypeCoercionSuite extends AnalysisTest { } test("type coercion for CaseKeyWhen") { - ruleTest(TypeCoercion.ImplicitTypeCasts(conf), + ruleTest(new TypeCoercion.ImplicitTypeCasts(conf), CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Literal(1.2))), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)) ) - ruleTest(TypeCoercion.CaseWhenCoercion(conf), + ruleTest(TypeCoercion.CaseWhenCoercion, CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))), CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))), Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))) @@ -1085,7 +1085,7 @@ class TypeCoercionSuite extends AnalysisTest { private val timeZoneResolver = ResolveTimeZone(new SQLConf) private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { - timeZoneResolver(TypeCoercion.WidenSetOperationTypes(conf)(plan)) + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) } test("WidenSetOperationTypes for except and intersect") { @@ -1256,7 +1256,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion(conf), Division) + val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1275,7 +1275,7 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion(conf), Division, ImplicitTypeCasts(conf)) + val rules = Seq(FunctionArgumentConversion, Division, new ImplicitTypeCasts(conf)) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,7 +898,6 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { - assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1023,20 +1022,14 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = { - assertOnDriver() - Option(activeThreadSession.get) - } + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = { - assertOnDriver() - Option(defaultSession.get) - } + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1069,14 +1062,6 @@ object SparkSession extends Logging { } } - private def assertOnDriver(): Unit = { - if (Utils.isTesting && TaskContext.get != null) { - // we're accessing it during task execution, fail. - throw new IllegalStateException( - "SparkSession should only be created and accessed on the driver.") - } - } - /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 1edf27619ad7b..f9a24806953e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaUtils @@ -522,8 +521,6 @@ object PartitioningUtils { private val findWiderTypeForPartitionColumn: (DataType, DataType) => DataType = { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => StringType case (DoubleType, LongType) | (LongType, DoubleType) => StringType - case (t1, t2) => - TypeCoercion.findWiderTypeForTwo( - t1, t2, SQLConf.get.caseSensitiveAnalysis).getOrElse(StringType) + case (t1, t2) => TypeCoercion.findWiderTypeForTwo(t1, t2).getOrElse(StringType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index e0424b7478122..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +44,6 @@ private[sql] object JsonInferSchema { createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord - val caseSensitive = SQLConf.get.caseSensitiveAnalysis // perform schema inference on each row and merge afterwards val rootType = json.mapPartitions { iter => @@ -55,7 +53,7 @@ private[sql] object JsonInferSchema { try { Utils.tryWithResource(createParser(factory, row)) { parser => parser.nextToken() - Some(inferField(parser, configOptions, caseSensitive)) + Some(inferField(parser, configOptions)) } } catch { case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match { @@ -70,7 +68,7 @@ private[sql] object JsonInferSchema { } } }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode, caseSensitive)) + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st @@ -100,15 +98,14 @@ private[sql] object JsonInferSchema { /** * Infer the type of a json document from the parser's token stream */ - private def inferField( - parser: JsonParser, configOptions: JSONOptions, caseSensitive: Boolean): DataType = { + private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = { import com.fasterxml.jackson.core.JsonToken._ parser.getCurrentToken match { case null | VALUE_NULL => NullType case FIELD_NAME => parser.nextToken() - inferField(parser, configOptions, caseSensitive) + inferField(parser, configOptions) case VALUE_STRING if parser.getTextLength < 1 => // Zero length strings and nulls have special handling to deal @@ -125,7 +122,7 @@ private[sql] object JsonInferSchema { while (nextUntil(parser, END_OBJECT)) { builder += StructField( parser.getCurrentName, - inferField(parser, configOptions, caseSensitive), + inferField(parser, configOptions), nullable = true) } val fields: Array[StructField] = builder.result() @@ -140,7 +137,7 @@ private[sql] object JsonInferSchema { var elementType: DataType = NullType while (nextUntil(parser, END_ARRAY)) { elementType = compatibleType( - elementType, inferField(parser, configOptions, caseSensitive), caseSensitive) + elementType, inferField(parser, configOptions)) } ArrayType(elementType) @@ -246,14 +243,13 @@ private[sql] object JsonInferSchema { */ private def compatibleRootType( columnNameOfCorruptRecords: String, - parseMode: ParseMode, - caseSensitive: Boolean): (DataType, DataType) => DataType = { + parseMode: ParseMode): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. case (ArrayType(ty1, _), ty2) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) case (ty1, ArrayType(ty2, _)) => - compatibleRootType(columnNameOfCorruptRecords, parseMode, caseSensitive)(ty1, ty2) + compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2) // Discard null/empty documents case (struct: StructType, NullType) => struct case (NullType, struct: StructType) => struct @@ -263,7 +259,7 @@ private[sql] object JsonInferSchema { withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode) // If we get anything else, we call compatibleType. // Usually, when we reach here, ty1 and ty2 are two StructTypes. - case (ty1, ty2) => compatibleType(ty1, ty2, caseSensitive) + case (ty1, ty2) => compatibleType(ty1, ty2) } private[this] val emptyStructFieldArray = Array.empty[StructField] @@ -271,8 +267,8 @@ private[sql] object JsonInferSchema { /** * Returns the most general data type for two given data types. */ - def compatibleType(t1: DataType, t2: DataType, caseSensitive: Boolean): DataType = { - TypeCoercion.findTightestCommonType(t1, t2, caseSensitive).getOrElse { + def compatibleType(t1: DataType, t2: DataType): DataType = { + TypeCoercion.findTightestCommonType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough @@ -307,8 +303,7 @@ private[sql] object JsonInferSchema { val f2Name = fields2(f2Idx).name val comp = f1Name.compareTo(f2Name) if (comp == 0) { - val dataType = compatibleType( - fields1(f1Idx).dataType, fields2(f2Idx).dataType, caseSensitive) + val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType) newFields.add(StructField(f1Name, dataType, nullable = true)) f1Idx += 1 f2Idx += 1 @@ -331,17 +326,15 @@ private[sql] object JsonInferSchema { StructType(newFields.toArray(emptyStructFieldArray)) case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType( - compatibleType(elementType1, elementType2, caseSensitive), - containsNull1 || containsNull2) + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when // the given `DecimalType` is not capable of the given `IntegralType`. case (t1: IntegralType, t2: DecimalType) => - compatibleType(DecimalType.forType(t1), t2, caseSensitive) + compatibleType(DecimalType.forType(t1), t2) case (t1: DecimalType, t2: IntegralType) => - compatibleType(t1, DecimalType.forType(t2), caseSensitive) + compatibleType(t1, DecimalType.forType(t2)) // strings and every string is a Json object. case (_, _) => StringType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 34d23ee53220d..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -122,10 +122,10 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("Get compatible type") { def checkDataType(t1: DataType, t2: DataType, expected: DataType) { - var actual = compatibleType(t1, t2, conf.caseSensitiveAnalysis) + var actual = compatibleType(t1, t2) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") - actual = compatibleType(t2, t1, conf.caseSensitiveAnalysis) + actual = compatibleType(t2, t1) assert(actual == expected, s"Expected $expected as the most general data type for $t1 and $t2, found $actual") } From 7c7caf8a33d8d542706a775f0689da5a57fd934a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 11 May 2018 16:32:38 +0800 Subject: [PATCH 2/8] support accessing SQLConf inside tasks --- .../org/apache/spark/TaskContextImpl.scala | 2 + .../catalyst/json/CreateJacksonParser.scala | 13 ---- .../spark/sql/internal/ReadOnlySQLConf.scala | 58 ++++++++++++++++ .../apache/spark/sql/internal/SQLConf.scala | 21 +++--- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/SQLExecution.scala | 30 +++++++-- .../execution/basicPhysicalOperators.scala | 2 +- .../datasources/json/JsonDataSource.scala | 10 +-- .../datasources/json/JsonInferSchema.scala | 17 +++-- .../exchange/BroadcastExchangeExec.scala | 2 +- .../datasources/json/JsonSuite.scala | 4 +- .../internal/ExecutorSideSQLConfSuite.scala | 66 +++++++++++++++++++ 12 files changed, 184 insertions(+), 45 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index cccd3ea457ba4..0791fe856ef15 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -178,4 +178,6 @@ private[spark] class TaskContextImpl( private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + // TODO: shall we publish it and define it in `TaskContext`? + private[spark] def getLocalProperties(): Properties = localProperties } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 3e8e6db1dbd22..783f94398954a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -78,17 +78,4 @@ private[sql] object CreateJacksonParser extends Serializable { def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { jsonFactory.createParser(new InputStreamReader(is, enc)) } - - def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { - val ba = row.getBinary(0) - - jsonFactory.createParser(ba, 0, ba.length) - } - - def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { - val binary = row.getBinary(0) - val sd = getStreamDecoder(enc, binary, binary.length) - - jsonFactory.createParser(sd) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala new file mode 100644 index 0000000000000..794fa7b270213 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.util.{Map => JMap} + +import org.apache.spark.{TaskContext, TaskContextImpl} +import org.apache.spark.internal.config.{ConfigEntry, ConfigProvider, ConfigReader} + +/** + * A readonly SQLConf that will be created by tasks running at the executor side. It reads the + * configs from the local properties which are propagated from driver to executors. + */ +class ReadOnlySQLConf(context: TaskContext) extends SQLConf { + + @transient override val settings: JMap[String, String] = { + context.asInstanceOf[TaskContextImpl].getLocalProperties().asInstanceOf[JMap[String, String]] + } + + @transient override protected val reader: ConfigReader = { + new ConfigReader(new TaskContextConfigProvider(context)) + } + + override protected def setConfWithCheck(key: String, value: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(key: String): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def unsetConf(entry: ConfigEntry[_]): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } + + override def clear(): Unit = { + throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") + } +} + +class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { + override def get(key: String): Option[String] = Option(context.getLocalProperty(key)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b00edca97cd44..30d9eda53756c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,13 +27,12 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -107,7 +106,13 @@ object SQLConf { * run tests in parallel. At the time this feature was implemented, this was a no-op since we * run unit tests (that does not involve SparkSession) in serial order. */ - def get: SQLConf = confGetter.get()() + def get: SQLConf = { + if (TaskContext.get != null) { + new ReadOnlySQLConf(TaskContext.get()) + } else { + confGetter.get()() + } + } val OPTIMIZER_MAX_ITERATIONS = buildConf("spark.sql.optimizer.maxIterations") .internal() @@ -1274,17 +1279,11 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ - if (Utils.isTesting && SparkEnv.get != null) { - // assert that we're only accessing it on the driver. - assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, - "SQLConf should only be created and accessed on the driver.") - } - /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) - @transient private val reader = new ConfigReader(settings) + @transient protected val reader = new ConfigReader(settings) /** ************************ Spark SQL Params/Hints ******************* */ @@ -1734,7 +1733,7 @@ class SQLConf extends Serializable with Logging { settings.containsKey(key) } - private def setConfWithCheck(key: String, value: String): Unit = { + protected def setConfWithCheck(key: String, value: String): Unit = { settings.put(key, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index d518e07bfb62c..90b2ee40f2af3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1607,7 +1607,9 @@ class Dataset[T] private[sql]( */ @Experimental @InterfaceStability.Evolving - def reduce(func: (T, T) => T): T = rdd.reduce(func) + def reduce(func: (T, T) => T): T = withNewRDDExecutionId { + rdd.reduce(func) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 2c5102b1e5ee7..5585c7dad7c0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -68,16 +68,27 @@ object SQLExecution { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on // streaming queries would give us call site like "run at :0" - val callSite = sparkSession.sparkContext.getCallSite() + val callSite = sc.getCallSite() - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionStart( + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + allConfigs.foreach { + // Excludes external configs defined by users. + case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value) + } + + sc.listenerBus.post(SparkListenerSQLExecutionStart( executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) try { body } finally { - sparkSession.sparkContext.listenerBus.post(SparkListenerSQLExecutionEnd( + sc.listenerBus.post(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) + allConfigs.foreach { + case (key, _) => sc.setLocalProperty(key, null) + } } } finally { executionIdToQueryExecution.remove(executionId) @@ -90,12 +101,23 @@ object SQLExecution { * thread from the original one, this method can be used to connect the Spark jobs in this action * with the known executionId, e.g., `BroadcastExchangeExec.relationFuture`. */ - def withExecutionId[T](sc: SparkContext, executionId: String)(body: => T): T = { + def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { + val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + // Set all the specified SQL configs to local properties, so that they can be available at + // the executor side. + val allConfigs = sparkSession.sessionState.conf.getAllConfs + allConfigs.foreach { + // Excludes external configs defined by users. + case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value) + } try { sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { + allConfigs.foreach { + case (key, _) => sc.setLocalProperty(key, null) + } sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd8..d54bfbfc14f5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -629,7 +629,7 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode { Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { val beforeCollect = System.nanoTime() // Note that we use .executeCollect() because we don't want to convert data to Scala types val rows: Array[InternalRow] = child.executeCollect() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index ba83df0efebd0..66ed9c9afb418 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -99,12 +99,7 @@ object TextInputJsonDataSource extends JsonDataSource { def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd - val rowParser = parsedOptions.encoding.map { enc => - CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) - }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) - - JsonInferSchema.infer(rdd, parsedOptions, rowParser) + JsonInferSchema.infer(sampled, parsedOptions, CreateJacksonParser.string) } private def createBaseDataset( @@ -165,7 +160,8 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + JsonInferSchema.infer[PortableDataStream]( + sparkSession.createDataset(sampled)(Encoders.javaSerialization), parsedOptions, parser) } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index a270a6451d5dd..3bfe4a31034d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -22,7 +22,7 @@ import java.util.Comparator import com.fasterxml.jackson.core._ import org.apache.spark.SparkException -import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Dataset, Encoders} import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions @@ -39,14 +39,14 @@ private[sql] object JsonInferSchema { * 3. Replace any remaining null fields with string, the top type */ def infer[T]( - json: RDD[T], + json: Dataset[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards - val rootType = json.mapPartitions { iter => + val inferredTypes = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -67,8 +67,15 @@ private[sql] object JsonInferSchema { } } } - }.fold(StructType(Nil))( - compatibleRootType(columnNameOfCorruptRecord, parseMode)) + }(Encoders.javaSerialization) + + // TODO: use `Dataset.fold` once we have it. + val rootType = try { + inferredTypes.reduce(compatibleRootType(columnNameOfCorruptRecord, parseMode)) + } catch { + case e: UnsupportedOperationException if e.getMessage == "empty collection" => + StructType(Nil) + } canonicalizeType(rootType) match { case Some(st: StructType) => st diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index daea6c39624d6..9e0ec9481b0de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -69,7 +69,7 @@ case class BroadcastExchangeExec( Future { // This will run in another thread. Set the execution id so that we can connect these jobs // with the correct execution. - SQLExecution.withExecutionId(sparkContext, executionId) { + SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) { try { val beforeCollect = System.nanoTime() // Use executeCollect/executeCollectIterator to avoid conversion to Scala types diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 4b3921c61a000..db7d3ef1abe7c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1374,7 +1374,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonInferSchema.infer on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty.rdd, + empty, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) @@ -1401,7 +1401,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords.rdd, + emptyRecords, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala new file mode 100644 index 0000000000000..404d6313ab92c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SQLTestUtils + +class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { + import testImplicits._ + + protected var spark: SparkSession = null + + // Create a new [[SparkSession]] running in local-cluster mode. + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + spark.stop() + spark = null + } + + test("ReadonlySQLConf is correctly created at the executor side") { + SQLConf.get.setConfString("spark.sql.x", "a") + try { + val checks = spark.range(10).mapPartitions { it => + val conf = SQLConf.get + Iterator(conf.isInstanceOf[ReadOnlySQLConf] && conf.getConfString("spark.sql.x") == "a") + }.collect() + assert(checks.forall(_ == true)) + } finally { + SQLConf.get.unsetConf("spark.sql.x") + } + } + + test("case-sensitive config should work for json schema inference") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + withTempPath { path => + val pathString = path.getCanonicalPath + spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).write.mode("append").json(pathString) + assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) + } + } + } +} From dc63dbef21d776f09a4e3be6aba1f1882e59d424 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 12 May 2018 00:44:35 +0800 Subject: [PATCH 3/8] address comments --- .../org/apache/spark/sql/internal/ReadOnlySQLConf.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala index 794fa7b270213..19f67236c8979 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/ReadOnlySQLConf.scala @@ -51,6 +51,14 @@ class ReadOnlySQLConf(context: TaskContext) extends SQLConf { override def clear(): Unit = { throw new UnsupportedOperationException("Cannot mutate ReadOnlySQLConf.") } + + override def clone(): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } + + override def copy(entries: (ConfigEntry[_], Any)*): SQLConf = { + throw new UnsupportedOperationException("Cannot clone/copy ReadOnlySQLConf.") + } } class TaskContextConfigProvider(context: TaskContext) extends ConfigProvider { From 2ecabe4fd984bb6a3f909364dcee27490c7a5d0a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 12 May 2018 02:49:32 +0800 Subject: [PATCH 4/8] fix json --- .../datasources/json/JsonDataSource.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 66ed9c9afb418..d67f8027a2ea6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -93,13 +93,21 @@ object TextInputJsonDataSource extends JsonDataSource { inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) - inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - JsonInferSchema.infer(sampled, parsedOptions, CreateJacksonParser.string) + if (parsedOptions.encoding.isDefined) { + // TODO: We should be able to parse the input string directly. Remove this hack when we + // support setting encoding when reading text files. + val encoding = parsedOptions.encoding.get + val textDS = sampled.map(new Text(_))(Encoders.javaSerialization) + val parser = CreateJacksonParser.text(encoding, _: JsonFactory, _: Text) + JsonInferSchema.infer(textDS, parsedOptions, parser) + } else { + JsonInferSchema.infer(sampled, parsedOptions, CreateJacksonParser.string) + } } private def createBaseDataset( @@ -107,10 +115,6 @@ object TextInputJsonDataSource extends JsonDataSource { inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { val paths = inputPaths.map(_.getPath.toString) - val textOptions = Map.empty[String, String] ++ - parsedOptions.encoding.map("encoding" -> _) ++ - parsedOptions.lineSeparator.map("lineSep" -> _) - sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, From bf8b42d494d4a8f21bd08b2fd6ed531e21e4eb49 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 13 May 2018 11:27:55 +0800 Subject: [PATCH 5/8] fix --- .../catalyst/json/CreateJacksonParser.scala | 13 ++++++ .../spark/sql/execution/SQLExecution.scala | 46 +++++++++---------- .../datasources/json/JsonDataSource.scala | 26 +++++------ .../datasources/json/JsonInferSchema.scala | 17 ++----- .../datasources/json/JsonSuite.scala | 4 +- 5 files changed, 56 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index 783f94398954a..3e8e6db1dbd22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -78,4 +78,17 @@ private[sql] object CreateJacksonParser extends Serializable { def inputStream(enc: String, jsonFactory: JsonFactory, is: InputStream): JsonParser = { jsonFactory.createParser(new InputStreamReader(is, enc)) } + + def internalRow(jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val ba = row.getBinary(0) + + jsonFactory.createParser(ba, 0, ba.length) + } + + def internalRow(enc: String, jsonFactory: JsonFactory, row: InternalRow): JsonParser = { + val binary = row.getBinary(0) + val sd = getStreamDecoder(enc, binary, binary.length) + + jsonFactory.createParser(sd) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 5585c7dad7c0f..5c5da41312d8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -70,24 +70,15 @@ object SQLExecution { // streaming queries would give us call site like "run at :0" val callSite = sc.getCallSite() - // Set all the specified SQL configs to local properties, so that they can be available at - // the executor side. - val allConfigs = sparkSession.sessionState.conf.getAllConfs - allConfigs.foreach { - // Excludes external configs defined by users. - case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value) - } - - sc.listenerBus.post(SparkListenerSQLExecutionStart( - executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) - try { - body - } finally { - sc.listenerBus.post(SparkListenerSQLExecutionEnd( - executionId, System.currentTimeMillis())) - allConfigs.foreach { - case (key, _) => sc.setLocalProperty(key, null) + withSQLConfPropagated(sparkSession) { + sc.listenerBus.post(SparkListenerSQLExecutionStart( + executionId, callSite.shortForm, callSite.longForm, queryExecution.toString, + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), System.currentTimeMillis())) + try { + body + } finally { + sc.listenerBus.post(SparkListenerSQLExecutionEnd( + executionId, System.currentTimeMillis())) } } } finally { @@ -104,21 +95,30 @@ object SQLExecution { def withExecutionId[T](sparkSession: SparkSession, executionId: String)(body: => T): T = { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + withSQLConfPropagated(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } + } + } + + def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { // Set all the specified SQL configs to local properties, so that they can be available at // the executor side. val allConfigs = sparkSession.sessionState.conf.getAllConfs - allConfigs.foreach { + for ((key, value) <- allConfigs) { // Excludes external configs defined by users. - case (key, value) if key.startsWith("spark") => sc.setLocalProperty(key, value) + if (key.startsWith("spark")) sparkSession.sparkContext.setLocalProperty(key, value) } try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) body } finally { allConfigs.foreach { - case (key, _) => sc.setLocalProperty(key, null) + case (key, _) => sparkSession.sparkContext.setLocalProperty(key, null) } - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index d67f8027a2ea6..3b6df45e949e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -93,20 +94,19 @@ object TextInputJsonDataSource extends JsonDataSource { inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): StructType = { val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) + inferFromDataset(json, parsedOptions) } def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = { val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions) - if (parsedOptions.encoding.isDefined) { - // TODO: We should be able to parse the input string directly. Remove this hack when we - // support setting encoding when reading text files. - val encoding = parsedOptions.encoding.get - val textDS = sampled.map(new Text(_))(Encoders.javaSerialization) - val parser = CreateJacksonParser.text(encoding, _: JsonFactory, _: Text) - JsonInferSchema.infer(textDS, parsedOptions, parser) - } else { - JsonInferSchema.infer(sampled, parsedOptions, CreateJacksonParser.string) + val rdd: RDD[InternalRow] = sampled.queryExecution.toRdd + val rowParser = parsedOptions.encoding.map { enc => + CreateJacksonParser.internalRow(enc, _: JsonFactory, _: InternalRow) + }.getOrElse(CreateJacksonParser.internalRow(_: JsonFactory, _: InternalRow)) + + SQLExecution.withSQLConfPropagated(json.sparkSession) { + JsonInferSchema.infer(rdd, parsedOptions, rowParser) } } @@ -114,11 +114,10 @@ object TextInputJsonDataSource extends JsonDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: JSONOptions): Dataset[String] = { - val paths = inputPaths.map(_.getPath.toString) sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, - paths = paths, + paths = inputPaths.map(_.getPath.toString), className = classOf[TextFileFormat].getName, options = parsedOptions.parameters ).resolveRelation(checkFilesExist = false)) @@ -164,8 +163,9 @@ object MultiLineJsonDataSource extends JsonDataSource { .map(enc => createParser(enc, _: JsonFactory, _: PortableDataStream)) .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) - JsonInferSchema.infer[PortableDataStream]( - sparkSession.createDataset(sampled)(Encoders.javaSerialization), parsedOptions, parser) + SQLExecution.withSQLConfPropagated(sparkSession) { + JsonInferSchema.infer[PortableDataStream](sampled, parsedOptions, parser) + } } private def createBaseRdd( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 3bfe4a31034d8..a270a6451d5dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -22,7 +22,7 @@ import java.util.Comparator import com.fasterxml.jackson.core._ import org.apache.spark.SparkException -import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil import org.apache.spark.sql.catalyst.json.JSONOptions @@ -39,14 +39,14 @@ private[sql] object JsonInferSchema { * 3. Replace any remaining null fields with string, the top type */ def infer[T]( - json: Dataset[T], + json: RDD[T], configOptions: JSONOptions, createParser: (JsonFactory, T) => JsonParser): StructType = { val parseMode = configOptions.parseMode val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord // perform schema inference on each row and merge afterwards - val inferredTypes = json.mapPartitions { iter => + val rootType = json.mapPartitions { iter => val factory = new JsonFactory() configOptions.setJacksonOptions(factory) iter.flatMap { row => @@ -67,15 +67,8 @@ private[sql] object JsonInferSchema { } } } - }(Encoders.javaSerialization) - - // TODO: use `Dataset.fold` once we have it. - val rootType = try { - inferredTypes.reduce(compatibleRootType(columnNameOfCorruptRecord, parseMode)) - } catch { - case e: UnsupportedOperationException if e.getMessage == "empty collection" => - StructType(Nil) - } + }.fold(StructType(Nil))( + compatibleRootType(columnNameOfCorruptRecord, parseMode)) canonicalizeType(rootType) match { case Some(st: StructType) => st diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index db7d3ef1abe7c..4b3921c61a000 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1374,7 +1374,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonInferSchema.infer on empty RDD") { // This is really a test that it doesn't throw an exception val emptySchema = JsonInferSchema.infer( - empty, + empty.rdd, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) @@ -1401,7 +1401,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-8093 Erase empty structs") { val emptySchema = JsonInferSchema.infer( - emptyRecords, + emptyRecords.rdd, new JSONOptions(Map.empty[String, String], "GMT"), CreateJacksonParser.string) assert(StructType(Seq()) === emptySchema) From a100dea9573e9b43b993516c817e306d80f72d29 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 May 2018 13:57:40 +0800 Subject: [PATCH 6/8] address comments --- .../org/apache/spark/sql/SparkSession.scala | 21 ++++++++++++++++--- .../spark/sql/execution/SQLExecution.scala | 13 ++++++++++-- 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c502e583a55c5..e2a1a57c7dd4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal -import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext} +import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext} import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -898,6 +898,7 @@ object SparkSession extends Logging { * @since 2.0.0 */ def getOrCreate(): SparkSession = synchronized { + assertOnDriver() // Get the session from current thread's active session. var session = activeThreadSession.get() if ((session ne null) && !session.sparkContext.isStopped) { @@ -1022,14 +1023,20 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get) + def getActiveSession: Option[SparkSession] = { + assertOnDriver() + Option(activeThreadSession.get) + } /** * Returns the default SparkSession that is returned by the builder. * * @since 2.2.0 */ - def getDefaultSession: Option[SparkSession] = Option(defaultSession.get) + def getDefaultSession: Option[SparkSession] = { + assertOnDriver() + Option(defaultSession.get) + } /** * Returns the currently active SparkSession, otherwise the default one. If there is no default @@ -1062,6 +1069,14 @@ object SparkSession extends Logging { } } + private def assertOnDriver(): Unit = { + if (Utils.isTesting && TaskContext.get != null) { + // we're accessing it during task execution, fail. + throw new IllegalStateException( + "SparkSession should only be created and accessed on the driver.") + } + } + /** * Helper method to create an instance of `SessionState` based on `className` from conf. * The result is either `SessionState` or a Hive based `SessionState`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 5c5da41312d8f..6695248abdda2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -109,15 +109,24 @@ object SQLExecution { // Set all the specified SQL configs to local properties, so that they can be available at // the executor side. val allConfigs = sparkSession.sessionState.conf.getAllConfs + val originalLocalProps = scala.collection.mutable.HashMap.empty[String, String] for ((key, value) <- allConfigs) { // Excludes external configs defined by users. - if (key.startsWith("spark")) sparkSession.sparkContext.setLocalProperty(key, value) + if (key.startsWith("spark")) { + Option(sparkSession.sparkContext.getLocalProperty(key)).foreach { + // If users already set a value in local properties, keep it and restore it at the end. + origin => originalLocalProps(key) = origin + } + sparkSession.sparkContext.setLocalProperty(key, value) + } } try { body } finally { allConfigs.foreach { - case (key, _) => sparkSession.sparkContext.setLocalProperty(key, null) + case (key, _) => + val origin = originalLocalProps.getOrElse(key, null) + sparkSession.sparkContext.setLocalProperty(key, origin) } } } From 01e288a332c25dcb9cd5af8c818ae7018dd6a9bb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 May 2018 21:11:40 +0800 Subject: [PATCH 7/8] fix --- .../main/scala/org/apache/spark/sql/SparkSession.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index e2a1a57c7dd4d..11d5bb52ec809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -90,11 +90,6 @@ class SparkSession private( sparkContext.assertNotStopped() - // If there is no active SparkSession, uses the default SQL conf. Otherwise, use the session's. - SQLConf.setSQLConfGetter(() => { - SparkSession.getActiveSession.map(_.sessionState.conf).getOrElse(SQLConf.getFallbackConf) - }) - /** * The version of Spark on which this application is running. * @@ -987,6 +982,9 @@ object SparkSession extends Logging { * @since 2.0.0 */ def setActiveSession(session: SparkSession): Unit = { + SQLConf.setSQLConfGetter(() => { + session.sessionState.conf + }) activeThreadSession.set(session) } From cfb76f59a71c04048c14892beba442753e72410f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 18 May 2018 17:25:26 +0800 Subject: [PATCH 8/8] address comment --- .../spark/sql/execution/SQLExecution.scala | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 6695248abdda2..032525a08ccdb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -106,27 +106,22 @@ object SQLExecution { } def withSQLConfPropagated[T](sparkSession: SparkSession)(body: => T): T = { + val sc = sparkSession.sparkContext // Set all the specified SQL configs to local properties, so that they can be available at // the executor side. val allConfigs = sparkSession.sessionState.conf.getAllConfs - val originalLocalProps = scala.collection.mutable.HashMap.empty[String, String] - for ((key, value) <- allConfigs) { - // Excludes external configs defined by users. - if (key.startsWith("spark")) { - Option(sparkSession.sparkContext.getLocalProperty(key)).foreach { - // If users already set a value in local properties, keep it and restore it at the end. - origin => originalLocalProps(key) = origin - } - sparkSession.sparkContext.setLocalProperty(key, value) - } + val originalLocalProps = allConfigs.collect { + case (key, value) if key.startsWith("spark") => + val originalValue = sc.getLocalProperty(key) + sc.setLocalProperty(key, value) + (key, originalValue) } + try { body } finally { - allConfigs.foreach { - case (key, _) => - val origin = originalLocalProps.getOrElse(key, null) - sparkSession.sparkContext.setLocalProperty(key, origin) + for ((key, value) <- originalLocalProps) { + sc.setLocalProperty(key, value) } } }