[SPARK-30098][SQL] Add a configuration to use default datasource as provider for CREATE TABLE command

### What changes were proposed in this pull request?

For the CREATE TABLE [AS SELECT] command, create a native Parquet table if neither USING nor STORED AS is specified and `spark.sql.legacy.createHiveTableByDefault` is false.

This is a retry after we unified the CREATE TABLE syntax. It partially reverts apache/spark@d2bec5e.

This PR also allows `CREATE EXTERNAL TABLE` when `LOCATION` is present. This was not allowed for data source tables before, which was an unnecessary behavior difference from Hive tables.
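
As a minimal sketch of the new behavior (not code from this PR; it assumes a local Spark session with Hive support, and the table name `t` is illustrative):

```scala
import org.apache.spark.sql.SparkSession

// Assumption: local session with Hive support; the legacy flag is turned off so that
// CREATE TABLE without USING / STORED AS uses the default data source (Parquet).
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("create-table-provider-sketch")
  .enableHiveSupport()
  .config("spark.sql.legacy.createHiveTableByDefault", "false")
  .getOrCreate()

// Neither USING nor STORED AS: with the flag off, this creates a native Parquet table
// instead of a Hive text table.
spark.sql("CREATE TABLE t (id INT, name STRING)")

// Inspect the effective provider/serde of the table that was created.
spark.sql("DESCRIBE TABLE EXTENDED t").show(100, truncate = false)
```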

### Why are the changes needed?

Changing from a Hive text table to a native Parquet table has several benefits:
1. Consistency with `DataFrameWriter.saveAsTable`.
2. Better performance.
3. Better support for nested types. Hive text tables do not handle nested types well: for example, `insert into t values struct(null)` actually inserts a null value instead of `struct(null)` if `t` is a Hive text table, which leads to wrong results (see the sketch after this list).
4. Better interoperability, as Parquet is a widely supported open file format.
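
As a hedged illustration of point 3 above (not part of this PR's diff; the exact output can vary across Spark/Hive versions, and the table names are made up):

```scala
// Assumption: the same `spark` session as in the earlier sketch.
spark.sql("CREATE TABLE hive_text_t (s STRUCT<a: INT>) STORED AS TEXTFILE")
spark.sql("INSERT INTO hive_text_t VALUES (struct(null))")
// With a Hive text table, the inserted row can come back as a plain NULL instead of {NULL}.
spark.sql("SELECT * FROM hive_text_t").show()

spark.sql("CREATE TABLE parquet_t (s STRUCT<a: INT>) USING parquet")
spark.sql("INSERT INTO parquet_t VALUES (struct(null))")
// The native Parquet table preserves the struct wrapper, i.e. the row shows as {NULL}.
spark.sql("SELECT * FROM parquet_t").show()
```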

### Does this PR introduce _any_ user-facing change?

No, not by default. If `spark.sql.legacy.createHiveTableByDefault` is set to false, the behavior change is described below:

Behavior-wise, the change is very small, as the native Parquet table is also Hive-compatible. All Spark DDL commands that work for Hive tables also work for native Parquet tables, with two exceptions: `ALTER TABLE SET [SERDE | SERDEPROPERTIES]` and `LOAD DATA`.

Char/varchar behavior has been taken care of by apache/spark#30412, so there is no behavior difference between data source and Hive tables.

One potential issue is `CREATE TABLE ... LOCATION ...` when users want to access the underlying files directly later. This is more of a corner case, and the legacy config should be good enough.

Another potential issue is that users may use Spark to create a table and then use Hive to add partitions with a different serde. This is not allowed for Spark native tables.
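
For users who hit these corner cases, here is a hedged sketch of keeping (or restoring) the old behavior via the legacy flag; the flag is a regular SQL conf, so it can also be set per session:

```scala
// Assumption: the same `spark` session as in the earlier sketches.
// With the flag set to true, CREATE TABLE without USING / STORED AS creates a Hive serde
// table again, which keeps the old behavior for direct file access under LOCATION or for
// adding partitions with a different serde from Hive.
spark.conf.set("spark.sql.legacy.createHiveTableByDefault", "true")
spark.sql("CREATE TABLE legacy_t (id INT) LOCATION '/tmp/legacy_t'")
```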

### How was this patch tested?

Re-enable the tests

Closes #30554 from cloud-fan/create-table.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
a0x8o committed Dec 3, 2020
1 parent 8cef24b commit 6b800e6
Showing 49 changed files with 1,179 additions and 268 deletions.
@@ -95,8 +95,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
conf.set(MEMORY_STORAGE_FRACTION, 0.999)
conf.set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L)

// to make a replication attempt to inactive store fail fast
conf.set("spark.core.connection.ack.wait.timeout", "1s")
// to make cached peers refresh frequently
conf.set(STORAGE_CACHED_PEERS_TTL, 10)

1 change: 1 addition & 0 deletions dev/.rat-excludes
@@ -123,6 +123,7 @@ SessionHandler.java
GangliaReporter.java
application_1578436911597_0052
config.properties
local-1596020211915
app-20200706201101-0003
py.typed
_metadata
11 changes: 0 additions & 11 deletions docs/configuration.md
@@ -1919,7 +1919,6 @@ Apart from these, the following properties are also available, and may be useful
<td>120s</td>
<td>
Default timeout for all network interactions. This config will be used in place of
<code>spark.core.connection.ack.wait.timeout</code>,
<code>spark.storage.blockManagerHeartbeatTimeoutMs</code>,
<code>spark.shuffle.io.connectionTimeout</code>, <code>spark.rpc.askTimeout</code> or
<code>spark.rpc.lookupTimeout</code> if they are not configured.
@@ -1982,16 +1981,6 @@ Apart from these, the following properties are also available, and may be useful
</td>
<td>1.4.0</td>
</tr>
<tr>
<td><code>spark.core.connection.ack.wait.timeout</code></td>
<td><code>spark.network.timeout</code></td>
<td>
How long for the connection to wait for ack to occur before timing
out and giving up. To avoid unwilling timeout caused by long pause like GC,
you can set larger value.
</td>
<td>1.1.1</td>
</tr>
<tr>
<td><code>spark.network.maxRemoteBlockSizeFetchToMem</code></td>
<td>200m</td>
29 changes: 17 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -254,20 +254,25 @@ class ImputerModel private[ml] (
/** @group setParam */
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

@transient private lazy val surrogates = {
val row = surrogateDF.head()
row.schema.fieldNames.zipWithIndex
.map { case (name, index) => (name, row.getDouble(index)) }
.toMap
}

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
val (inputColumns, outputColumns) = getInOutCols
val surrogates = surrogateDF.select(inputColumns.map(col): _*).head().toSeq


val newCols = inputColumns.zip(outputColumns).zip(surrogates).map {
case ((inputCol, outputCol), surrogate) =>
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol).cast(DoubleType)
when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
.cast(inputType)
val (inputColumns, outputColumns) = getInOutCols()

val newCols = inputColumns.map { inputCol =>
val surrogate = surrogates(inputCol)
val inputType = dataset.schema(inputCol).dataType
val ic = col(inputCol).cast(DoubleType)
when(ic.isNull, surrogate)
.when(ic === $(missingValue), surrogate)
.otherwise(ic)
.cast(inputType)
}
dataset.withColumns(outputColumns, newCols).toDF()
}
12 changes: 12 additions & 0 deletions python/pyspark/ml/feature.py
@@ -3852,9 +3852,21 @@ def from_arrays_of_labels(cls, arrayOfLabels, inputCols, outputCols=None,
def labels(self):
"""
Ordered list of labels, corresponding to indices to be assigned.
.. deprecated:: 3.1.0
It will be removed in future versions. Use `labelsArray` method instead.
"""
return self._call_java("labels")

@property
@since("3.0.2")
def labelsArray(self):
"""
Array of ordered list of labels, corresponding to indices to be assigned
for each input column.
"""
return self._call_java("labelsArray")


@inherit_doc
class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
1 change: 1 addition & 0 deletions python/pyspark/ml/tests/test_feature.py
@@ -232,6 +232,7 @@ def test_string_indexer_from_labels(self):
model = StringIndexerModel.from_labels(["a", "b", "c"], inputCol="label",
outputCol="indexed", handleInvalid="keep")
self.assertEqual(model.labels, ["a", "b", "c"])
self.assertEqual(model.labelsArray, [("a", "b", "c")])

df1 = self.spark.createDataFrame([
(0, "a"),
@@ -427,9 +427,14 @@ object FunctionRegistry {
expression[MakeInterval]("make_interval"),
expression[DatePart]("date_part"),
expression[Extract]("extract"),
expression[DateFromUnixDate]("date_from_unix_date"),
expression[UnixDate]("unix_date"),
expression[SecondsToTimestamp]("timestamp_seconds"),
expression[MillisToTimestamp]("timestamp_millis"),
expression[MicrosToTimestamp]("timestamp_micros"),
expression[UnixSeconds]("unix_seconds"),
expression[UnixMillis]("unix_millis"),
expression[UnixMicros]("unix_micros"),

// collection functions
expression[CreateArray]("array"),
@@ -1393,25 +1393,19 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
"""
}

private[this] def lowerAndUpperBound(
fractionType: String,
integralType: String): (String, String) = {
assert(fractionType == "float" || fractionType == "double")
val typeIndicator = fractionType.charAt(0)
val (min, max) = integralType.toLowerCase(Locale.ROOT) match {
case "long" => (Long.MinValue, Long.MaxValue)
case "int" => (Int.MinValue, Int.MaxValue)
case "short" => (Short.MinValue, Short.MaxValue)
case "byte" => (Byte.MinValue, Byte.MaxValue)
private[this] def lowerAndUpperBound(integralType: String): (String, String) = {
val (min, max, typeIndicator) = integralType.toLowerCase(Locale.ROOT) match {
case "long" => (Long.MinValue, Long.MaxValue, "L")
case "int" => (Int.MinValue, Int.MaxValue, "")
case "short" => (Short.MinValue, Short.MaxValue, "")
case "byte" => (Byte.MinValue, Byte.MaxValue, "")
}
(min.toString + typeIndicator, max.toString + typeIndicator)
}

private[this] def castFractionToIntegralTypeCode(
fractionType: String,
integralType: String): CastFunction = {
private[this] def castFractionToIntegralTypeCode(integralType: String): CastFunction = {
assert(ansiEnabled)
val (min, max) = lowerAndUpperBound(fractionType, integralType)
val (min, max) = lowerAndUpperBound(integralType)
val mathClass = classOf[Math].getName
// When casting floating values to integral types, Spark uses the method `Numeric.toInt`
// Or `Numeric.toLong` directly. For positive floating values, it is equivalent to `Math.floor`;
@@ -1449,12 +1443,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "byte")
case _: ShortType | _: IntegerType | _: LongType if ansiEnabled =>
case ShortType | IntegerType | LongType if ansiEnabled =>
castIntegralTypeToIntegralTypeExactCode("byte")
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "byte")
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "byte")
case FloatType | DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("byte")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (byte) $c;"
}
@@ -1482,12 +1474,10 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "short")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "short")
case _: IntegerType | _: LongType if ansiEnabled =>
case IntegerType | LongType if ansiEnabled =>
castIntegralTypeToIntegralTypeExactCode("short")
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "short")
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "short")
case FloatType | DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("short")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (short) $c;"
}
@@ -1513,11 +1503,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
(c, evPrim, evNull) => code"$evNull = true;"
case TimestampType => castTimestampToIntegralTypeCode(ctx, "int")
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "int")
case _: LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("int")
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "int")
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "int")
case LongType if ansiEnabled => castIntegralTypeToIntegralTypeExactCode("int")
case FloatType | DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("int")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (int) $c;"
}
@@ -1544,10 +1532,8 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
case TimestampType =>
(c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};"
case DecimalType() => castDecimalToIntegralTypeCode(ctx, "long")
case _: FloatType if ansiEnabled =>
castFractionToIntegralTypeCode("float", "long")
case _: DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("double", "long")
case FloatType | DoubleType if ansiEnabled =>
castFractionToIntegralTypeCode("long")
case x: NumericType =>
(c, evPrim, evNull) => code"$evPrim = (long) $c;"
}
@@ -400,6 +400,52 @@ case class DayOfYear(child: Expression) extends GetDateField {
override val funcName = "getDayInYear"
}

@ExpressionDescription(
usage = "_FUNC_(days) - Create date from the number of days since 1970-01-01.",
examples = """
Examples:
> SELECT _FUNC_(1);
1970-01-02
""",
group = "datetime_funcs",
since = "3.1.0")
case class DateFromUnixDate(child: Expression) extends UnaryExpression
with ImplicitCastInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType)

override def dataType: DataType = DateType

override def nullSafeEval(input: Any): Any = input.asInstanceOf[Int]

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)

override def prettyName: String = "date_from_unix_date"
}

@ExpressionDescription(
usage = "_FUNC_(date) - Returns the number of days since 1970-01-01.",
examples = """
Examples:
> SELECT _FUNC_(DATE("1970-01-02"));
1
""",
group = "datetime_funcs",
since = "3.1.0")
case class UnixDate(child: Expression) extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)

override def dataType: DataType = IntegerType

override def nullSafeEval(input: Any): Any = input.asInstanceOf[Int]

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, c => c)

override def prettyName: String = "unix_date"
}

abstract class IntegralToTimestampBase extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {

@@ -524,6 +570,79 @@ case class MicrosToTimestamp(child: Expression)
override def prettyName: String = "timestamp_micros"
}

abstract class TimestampToLongBase extends UnaryExpression
with ExpectsInputTypes with NullIntolerant {

protected def scaleFactor: Long

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)

override def dataType: DataType = LongType

override def nullSafeEval(input: Any): Any = {
Math.floorDiv(input.asInstanceOf[Number].longValue(), scaleFactor)
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (scaleFactor == 1) {
defineCodeGen(ctx, ev, c => c)
} else {
defineCodeGen(ctx, ev, c => s"java.lang.Math.floorDiv($c, ${scaleFactor}L)")
}
}
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(timestamp) - Returns the number of seconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of precision.",
examples = """
Examples:
> SELECT _FUNC_(TIMESTAMP('1970-01-01 00:00:01Z'));
1
""",
group = "datetime_funcs",
since = "3.1.0")
// scalastyle:on line.size.limit
case class UnixSeconds(child: Expression) extends TimestampToLongBase {
override def scaleFactor: Long = MICROS_PER_SECOND

override def prettyName: String = "unix_seconds"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(timestamp) - Returns the number of milliseconds since 1970-01-01 00:00:00 UTC. Truncates higher levels of precision.",
examples = """
Examples:
> SELECT _FUNC_(TIMESTAMP('1970-01-01 00:00:01Z'));
1000
""",
group = "datetime_funcs",
since = "3.1.0")
// scalastyle:on line.size.limit
case class UnixMillis(child: Expression) extends TimestampToLongBase {
override def scaleFactor: Long = MICROS_PER_MILLIS

override def prettyName: String = "unix_millis"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(timestamp) - Returns the number of microseconds since 1970-01-01 00:00:00 UTC.",
examples = """
Examples:
> SELECT _FUNC_(TIMESTAMP('1970-01-01 00:00:01Z'));
1000000
""",
group = "datetime_funcs",
since = "3.1.0")
// scalastyle:on line.size.limit
case class UnixMicros(child: Expression) extends TimestampToLongBase {
override def scaleFactor: Long = 1L

override def prettyName: String = "unix_micros"
}

@ExpressionDescription(
usage = "_FUNC_(date) - Returns the year component of the date/timestamp.",
examples = """
@@ -2921,6 +2921,15 @@ object SQLConf {
.stringConf
.createWithDefault("")

val LEGACY_CREATE_HIVE_TABLE_BY_DEFAULT =
buildConf("spark.sql.legacy.createHiveTableByDefault")
.internal()
.doc("When set to true, CREATE TABLE syntax without USING or STORED AS will use Hive " +
s"instead of the value of ${DEFAULT_DATA_SOURCE_NAME.key} as the table provider.")
.version("3.1.0")
.booleanConf
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
*
@@ -975,6 +975,11 @@ abstract class AnsiCastSuiteBase extends CastSuiteBase {
}
}
}

test("SPARK-26218: Fix the corner case of codegen when casting float to Integer") {
checkExceptionInExpression[ArithmeticException](
cast(cast(Literal("2147483648"), FloatType), IntegerType), "overflow")
}
}

/**