@@ -360,6 +360,31 @@ object Cast extends QueryErrorsBase {
*/
def canUpCast(from: DataType, to: DataType): Boolean = UpCastRule.canUpCast(from, to)

/**
* Returns true iff it is safe to provide a default value of the `from` type (typically defined
* in the data source metadata) for the `to` type (typically in the read schema of a query).
*/
def canAssignDefaultValue(from: DataType, to: DataType): Boolean = {
def isVariantStruct(st: StructType): Boolean = {
st.fields.length > 0 && st.fields.forall(_.metadata.contains("__VARIANT_METADATA_KEY"))
}
(from, to) match {
case (s1: StructType, s2: StructType) =>
s1.length == s2.length && s1.fields.zip(s2.fields).forall {
case (f1, f2) => resolvableNullability(f1.nullable, f2.nullable) &&
canAssignDefaultValue(f1.dataType, f2.dataType)
}
case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
resolvableNullability(fn, tn) && canAssignDefaultValue(fromType, toType)
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
resolvableNullability(fn, tn) && canAssignDefaultValue(fromKey, toKey) &&
canAssignDefaultValue(fromValue, toValue)
// A VARIANT field can be read as StructType due to shredding.
case (VariantType, s: StructType) => isVariantStruct(s)
case _ => canUpCast(from, to)
}
}

/**
* Returns true iff we can cast the `from` type to `to` type as per the ANSI SQL.
* In practice, the behavior is mostly the same as PostgreSQL. It disallows certain unreasonable
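For context, the new check only relaxes the up-cast rules in the shredded-variant case; every other type pair still defers to Cast.canUpCast. A minimal sketch of the intended behavior (the helper name and the "__VARIANT_METADATA_KEY" literal come from the hunk above; the concrete fields are made up for illustration):

import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// A struct standing in for a shredded VARIANT column: every field carries the marker key.
val shredded = StructType(Seq(StructField("a", IntegerType,
  metadata = new MetadataBuilder().putBoolean("__VARIANT_METADATA_KEY", true).build())))

Cast.canAssignDefaultValue(VariantType, shredded)                     // true: shredded variant struct
Cast.canAssignDefaultValue(VariantType, StructType.fromDDL("a int"))  // false: ordinary struct
Cast.canAssignDefaultValue(IntegerType, LongType)                     // true: falls back to canUpCast
Cast.canAssignDefaultValue(LongType, IntegerType)                     // false: lossy narrowing
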
@@ -480,7 +480,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
val ret = analyzed match {
case equivalent if equivalent.dataType == supplanted =>
equivalent
case canUpCast if Cast.canUpCast(canUpCast.dataType, supplanted) =>
case _ if Cast.canAssignDefaultValue(analyzed.dataType, supplanted) =>
Cast(analyzed, supplanted, Some(conf.sessionLocalTimeZone))
case other =>
defaultValueFromWiderTypeLiteral(other, supplanted, colName).getOrElse(
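The practical effect of this one-line change: a default value whose analyzed type is VARIANT can now be cast to the shredded struct type exposed by the scan, instead of falling through to the wider-type-literal/error path, because canUpCast rejects that pair while canAssignDefaultValue accepts it. A hedged illustration, reusing the hand-built shredded struct shape from the sketch above (real shredded schemas come from the scan):

Cast.canUpCast(VariantType, shredded)              // false -> old code took the fallback/error path
Cast.canAssignDefaultValue(VariantType, shredded)  // true  -> Cast(analyzed, supplanted, ...) is inserted
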
64 changes: 36 additions & 28 deletions sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -197,36 +197,39 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
}

test("round trip tests") {
val rand = new Random(42)
val input = Seq.fill(50) {
if (rand.nextInt(10) == 0) {
null
} else {
val value = new Array[Byte](rand.nextInt(50))
rand.nextBytes(value)
val metadata = new Array[Byte](rand.nextInt(50))
rand.nextBytes(metadata)
// Generate a valid metadata, otherwise the shredded reader will fail.
new VariantVal(value, Array[Byte](VERSION, 0, 0) ++ metadata)
withSQLConf(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false") {
val rand = new Random(42)
val input = Seq.fill(50) {
if (rand.nextInt(10) == 0) {
null
} else {
val value = new Array[Byte](rand.nextInt(50))
rand.nextBytes(value)
val metadata = new Array[Byte](rand.nextInt(50))
rand.nextBytes(metadata)
// Generate a valid metadata, otherwise the shredded reader will fail.
new VariantVal(value, Array[Byte](VERSION, 0, 0) ++ metadata)
}
}
}

val df = spark.createDataFrame(
spark.sparkContext.parallelize(input.map(Row(_))),
StructType.fromDDL("v variant")
)
val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])
val df = spark.createDataFrame(
spark.sparkContext.parallelize(input.map(Row(_))),
StructType.fromDDL("v variant")
)
val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])

def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
values.map(v => if (v == null) "null" else v.debugString()).sorted
}
assert(prepareAnswer(input) == prepareAnswer(result.toImmutableArraySeq))
def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
values.map(v => if (v == null) "null" else v.debugString()).sorted
}
assert(prepareAnswer(input) == prepareAnswer(result.toImmutableArraySeq))

withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath
df.write.parquet(tempDir)
val readResult = spark.read.parquet(tempDir).collect().map(_.get(0).asInstanceOf[VariantVal])
assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath
df.write.parquet(tempDir)
val readResult = spark.read.parquet(tempDir).collect()
.map(_.get(0).asInstanceOf[VariantVal])
assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
}
}
}
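
The round-trip test is now wrapped in VARIANT_INFER_SHREDDING_SCHEMA = false. A plausible reading of why (inferred from the test data, not stated in the diff): inferring a shredding schema requires parsing the variant values at write time, and the randomly generated value bytes here are not well-formed variants even though the metadata header is made valid. Reduced to its essentials, the pattern is:

// Sketch only; `df` holds synthetic VariantVal rows as in the test above.
withSQLConf(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false") {
  df.write.parquet(tempDir)                        // written as plain (unshredded) variant binaries
  val readBack = spark.read.parquet(tempDir).collect()
}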

@@ -383,14 +386,19 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
)
cases.foreach { case (structDef, condition, parameters) =>
Seq(false, true).foreach { vectorizedReader =>
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString) {
withSQLConf(
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString,
// Invalid variant binary fails during shredding schema inference.
SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
) {
withTempDir { dir =>
val file = new File(dir, "dir").getCanonicalPath
val df = spark.sql(s"select $structDef as v from range(10)")
df.write.parquet(file)
val schema = StructType(Seq(StructField("v", VariantType)))
val result = spark.read.schema(schema).parquet(file).selectExpr("to_json(v)")
val e = withSQLConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "false") {
val e = withSQLConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "false",
SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") {
intercept[org.apache.spark.SparkException](result.collect())
}
checkError(
@@ -48,7 +48,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share

test("timestamp physical type") {
ParquetOutputTimestampType.values.foreach { timestampParquetType =>
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString) {
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString,
SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> "true") {
withTempDir { dir =>
val schema = "t timestamp, st struct<t timestamp>, at array<timestamp>"
val fullSchema = "v struct<metadata binary, value binary, typed_value struct<" +
@@ -232,7 +233,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
test("variant logical type annotation - ignore variant annotation") {
Seq(true, false).foreach { ignoreVariantAnnotation =>
withSQLConf(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "true",
SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString
SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString,
SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
) {
withTempDir { dir =>
// write parquet file
@@ -302,7 +304,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
"c struct<value binary, typed_value decimal(15, 1)>>>"
withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> true.toString,
SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema,
SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> true.toString) {
df.write.mode("overwrite").parquet(dir.getAbsolutePath)


@@ -441,7 +444,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
"m map<string, struct<metadata binary, value binary>>"
withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> true.toString,
SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema,
SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> true.toString) {
df.write.mode("overwrite").parquet(dir.getAbsolutePath)

// Verify that we can read the full variant.
@@ -41,6 +41,9 @@ class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with
super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
.set(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key, "true")
.set(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key, "true")
// We cannot check the physical shredding schemas if the variant logical type annotation is
// used.
.set(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key, "false")
}

private def withTempTable(tableNames: String*)(f: => Unit): Unit = {