From 23878184cf66765d2df1b1053f778402424bdbc5 Mon Sep 17 00:00:00 2001
From: Harsh Motwani
Date: Wed, 26 Nov 2025 00:14:21 +0000
Subject: [PATCH 1/2] shredding test fixes and column default fix

---
 .../spark/sql/catalyst/expressions/Cast.scala | 25 ++++++++
 .../util/ResolveDefaultColumnsUtil.scala      |  2 +-
 .../org/apache/spark/sql/VariantSuite.scala   | 64 +++++++++++--------
 .../ParquetVariantShreddingSuite.scala        | 12 ++--
 .../parquet/VariantInferShreddingSuite.scala  |  3 +
 5 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1f2805ec2789..1162a5394221 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -360,6 +360,31 @@
    */
   def canUpCast(from: DataType, to: DataType): Boolean = UpCastRule.canUpCast(from, to)
 
+  /**
+   * Returns true iff it is safe to provide a default value of `from` type typically defined in the
+   * data source metadata to the `to` type typically in the read schema of a query.
+   */
+  def canAssignDefaultValue(from: DataType, to: DataType): Boolean = {
+    def isVariantStruct(st: StructType): Boolean = {
+      st.fields.length > 0 && st.fields.forall(_.metadata.contains("__VARIANT_METADATA_KEY"))
+    }
+    (from, to) match {
+      case (s1: StructType, s2: StructType) =>
+        s1.length == s2.length && s1.fields.zip(s2.fields).forall {
+          case (f1, f2) => resolvableNullability(f1.nullable, f2.nullable) &&
+            canAssignDefaultValue(f1.dataType, f2.dataType)
+        }
+      case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+        resolvableNullability(fn, tn) && canAssignDefaultValue(fromType, toType)
+      case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+        resolvableNullability(fn, tn) && canAssignDefaultValue(fromKey, toKey) &&
+          canAssignDefaultValue(fromValue, toValue)
+      // A VARIANT field can be read as StructType due to shredding.
+      case (VariantType, s: StructType) => isVariantStruct(s)
+      case _ => canUpCast(from, to)
+    }
+  }
+
   /**
    * Returns true iff we can cast the `from` type to `to` type as per the ANSI SQL.
    * In practice, the behavior is mostly the same as PostgreSQL. It disallows certain unreasonable
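The new predicate only widens Cast.canUpCast for the default-value path, and the variant case is keyed off the per-field shredding marker checked by isVariantStruct. A minimal sketch of the intended behavior (the field name, the MetadataBuilder call, and the stored value below are illustrative assumptions, not taken from the patch):

    import org.apache.spark.sql.catalyst.expressions.Cast
    import org.apache.spark.sql.types._

    // Struct produced by variant shredding: every field carries the marker key
    // checked by isVariantStruct above (hypothetical field name and value).
    val shredded = StructType(Seq(
      StructField("typed_value", StringType, nullable = true,
        metadata = new MetadataBuilder().putString("__VARIANT_METADATA_KEY", "").build())))
    // Plain struct without the marker.
    val plain = StructType(Seq(StructField("a", IntegerType)))

    Cast.canAssignDefaultValue(VariantType, shredded)  // true: shredded read schema for a variant column
    Cast.canAssignDefaultValue(VariantType, plain)     // false: not a shredding struct
    Cast.canAssignDefaultValue(IntegerType, LongType)  // true: falls back to canUpCast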
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 4bef21d0a091..488d1acf43ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -480,7 +480,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
     val ret = analyzed match {
       case equivalent if equivalent.dataType == supplanted =>
         equivalent
-      case canUpCast if Cast.canUpCast(canUpCast.dataType, supplanted) =>
+      case _ if Cast.canAssignDefaultValue(analyzed.dataType, supplanted) =>
         Cast(analyzed, supplanted, Some(conf.sessionLocalTimeZone))
       case other =>
         defaultValueFromWiderTypeLiteral(other, supplanted, colName).getOrElse(
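With the old guard, Cast.canUpCast, an analyzed variant default could not be adapted when the scan reports the shredded struct type for the column, so resolution fell through to the defaultValueFromWiderTypeLiteral path. A rough sketch of the delta, reusing the hypothetical `shredded` struct from the sketch above:

    // Old guard: there is no upcast rule from VariantType to a struct, so the Cast was never added.
    Cast.canUpCast(VariantType, shredded)              // false
    // New guard: the shredded struct is recognized, so the default is cast to the read schema.
    Cast.canAssignDefaultValue(VariantType, shredded)  // true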
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index ac6a4e435709..141aa16b8dc2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -197,36 +197,39 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
   }
 
   test("round trip tests") {
-    val rand = new Random(42)
-    val input = Seq.fill(50) {
-      if (rand.nextInt(10) == 0) {
-        null
-      } else {
-        val value = new Array[Byte](rand.nextInt(50))
-        rand.nextBytes(value)
-        val metadata = new Array[Byte](rand.nextInt(50))
-        rand.nextBytes(metadata)
-        // Generate a valid metadata, otherwise the shredded reader will fail.
-        new VariantVal(value, Array[Byte](VERSION, 0, 0) ++ metadata)
+    withSQLConf(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false") {
+      val rand = new Random(42)
+      val input = Seq.fill(50) {
+        if (rand.nextInt(10) == 0) {
+          null
+        } else {
+          val value = new Array[Byte](rand.nextInt(50))
+          rand.nextBytes(value)
+          val metadata = new Array[Byte](rand.nextInt(50))
+          rand.nextBytes(metadata)
+          // Generate a valid metadata, otherwise the shredded reader will fail.
+          new VariantVal(value, Array[Byte](VERSION, 0, 0) ++ metadata)
+        }
       }
-    }
 
-    val df = spark.createDataFrame(
-      spark.sparkContext.parallelize(input.map(Row(_))),
-      StructType.fromDDL("v variant")
-    )
-    val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])
+      val df = spark.createDataFrame(
+        spark.sparkContext.parallelize(input.map(Row(_))),
+        StructType.fromDDL("v variant")
+      )
+      val result = df.collect().map(_.get(0).asInstanceOf[VariantVal])
 
-    def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
-      values.map(v => if (v == null) "null" else v.debugString()).sorted
-    }
-    assert(prepareAnswer(input) == prepareAnswer(result.toImmutableArraySeq))
+      def prepareAnswer(values: Seq[VariantVal]): Seq[String] = {
+        values.map(v => if (v == null) "null" else v.debugString()).sorted
+      }
+      assert(prepareAnswer(input) == prepareAnswer(result.toImmutableArraySeq))
 
-    withTempDir { dir =>
-      val tempDir = new File(dir, "files").getCanonicalPath
-      df.write.parquet(tempDir)
-      val readResult = spark.read.parquet(tempDir).collect().map(_.get(0).asInstanceOf[VariantVal])
-      assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
+      withTempDir { dir =>
+        val tempDir = new File(dir, "files").getCanonicalPath
+        df.write.parquet(tempDir)
+        val readResult = spark.read.parquet(tempDir).collect()
+          .map(_.get(0).asInstanceOf[VariantVal])
+        assert(prepareAnswer(input) == prepareAnswer(readResult.toImmutableArraySeq))
+      }
     }
   }
 
@@ -383,14 +386,19 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
     )
     cases.foreach { case (structDef, condition, parameters) =>
       Seq(false, true).foreach { vectorizedReader =>
-        withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString) {
+        withSQLConf(
+          SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString,
+          // Invalid variant binary fails during shredding schema inference
+          SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
+        ) {
           withTempDir { dir =>
             val file = new File(dir, "dir").getCanonicalPath
             val df = spark.sql(s"select $structDef as v from range(10)")
             df.write.parquet(file)
             val schema = StructType(Seq(StructField("v", VariantType)))
             val result = spark.read.schema(schema).parquet(file).selectExpr("to_json(v)")
-            val e = withSQLConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "false") {
+            val e = withSQLConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> "false",
+              SQLConf.PUSH_VARIANT_INTO_SCAN.key -> "false") {
               intercept[org.apache.spark.SparkException](result.collect())
             }
             checkError(
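The tests scope the new flag with withSQLConf; outside the test helpers the same effect is a set/restore of the session conf. A minimal sketch, assuming a live SparkSession named `spark` and the `df`/`tempDir` values from the round-trip test above:

    import org.apache.spark.sql.internal.SQLConf

    // Turn shredding-schema inference off only around the write of randomly
    // generated (possibly malformed) variant binaries, then restore the old value.
    val key = SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key
    val previous = spark.conf.getOption(key)
    spark.conf.set(key, "false")
    try {
      df.write.parquet(tempDir)
    } finally {
      previous match {
        case Some(v) => spark.conf.set(key, v)
        case None => spark.conf.unset(key)
      }
    }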
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
index 77140c1a91ee..1f06ddb29bd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
@@ -48,7 +48,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
 
   test("timestamp physical type") {
     ParquetOutputTimestampType.values.foreach { timestampParquetType =>
-      withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString) {
+      withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> timestampParquetType.toString,
+        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> "true") {
         withTempDir { dir =>
           val schema = "t timestamp, st struct, at array"
           val fullSchema = "v struct

       withSQLConf(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "true",
-        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString
+        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString,
+        SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
       ) {
         withTempDir { dir =>
           // write parquet file
@@ -302,7 +304,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
       "c struct>>"
     withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
       SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> true.toString,
-      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema,
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> true.toString) {
 
       df.write.mode("overwrite").parquet(dir.getAbsolutePath)
 
@@ -441,7 +444,8 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
       "m map>"
     withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> true.toString,
       SQLConf.VARIANT_ALLOW_READING_SHREDDED.key -> true.toString,
-      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema) {
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> schema,
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> true.toString) {
       df.write.mode("overwrite").parquet(dir.getAbsolutePath)
 
       // Verify that we can read the full variant.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
index cdaf6c488dc2..49a43fffafb3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/VariantInferShreddingSuite.scala
@@ -41,6 +41,9 @@ class VariantInferShreddingSuite extends QueryTest with SharedSparkSession with
     super.sparkConf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
       .set(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key, "true")
       .set(SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key, "true")
+      // We cannot check the physical shredding schemas if the variant logical type annotation is
+      // used
+      .set(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key, "false")
   }
 
   private def withTempTable(tableNames: String*)(f: => Unit): Unit = {
From c2bc63f466b88e0212914a50ea54fba0fec30f2a Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 26 Nov 2025 12:48:12 +0800
Subject: [PATCH 2/2] Update sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala

---
 sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
index 141aa16b8dc2..16be9558409c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantSuite.scala
@@ -388,7 +388,7 @@ class VariantSuite extends QueryTest with SharedSparkSession with ExpressionEval
       Seq(false, true).foreach { vectorizedReader =>
         withSQLConf(
           SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorizedReader.toString,
-          // Invalid variant binary fails during shredding schema inference
+          // Invalid variant binary fails during shredding schema inference.
           SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> "false"
         ) {
           withTempDir { dir =>