diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4e9caa822997..b2bee3075f1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1593,6 +1593,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val PARQUET_IGNORE_VARIANT_ANNOTATION =
+    buildConf("spark.sql.parquet.ignoreVariantAnnotation")
+      .internal()
+      .doc("When true, ignore the variant logical type annotation and treat the Parquet " +
+        "column in the same way as the underlying struct type")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PARQUET_FIELD_ID_READ_ENABLED =
     buildConf("spark.sql.parquet.fieldId.read.enabled")
       .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " +
@@ -5585,7 +5594,7 @@ object SQLConf {
         "When false, it only reads unshredded variant.")
       .version("4.0.0")
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)
 
   val PUSH_VARIANT_INTO_SCAN =
     buildConf("spark.sql.variant.pushVariantIntoScan")
@@ -7802,6 +7811,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def parquetAnnotateVariantLogicalType: Boolean = getConf(PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE)
 
+  def parquetIgnoreVariantAnnotation: Boolean = getConf(SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION)
+
   def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID)
 
   def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index d708a19dd1ac..271a1485dfd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -876,7 +876,11 @@ private[parquet] class ParquetRowConverter(
     }
   }
 
-  /** Parquet converter for unshredded Variant */
+  /**
+   * Parquet converter for unshredded Variant. We use this converter when the
+   * `spark.sql.variant.allowReadingShredded` config is set to false. This option just exists to
+   * fall back to legacy logic which will eventually be removed.
+   */
   private final class ParquetUnshreddedVariantConverter(
       parquetType: GroupType,
       updater: ParentContainerUpdater)
@@ -890,29 +894,27 @@ private[parquet] class ParquetRowConverter(
         // We may allow more than two children in the future, so consider this unsupported.
         throw QueryCompilationErrors.invalidVariantWrongNumFieldsError()
       }
-      val valueAndMetadata = Seq("value", "metadata").map { colName =>
+      val Seq(value, metadata) = Seq("value", "metadata").map { colName =>
         val idx = (0 until parquetType.getFieldCount())
-          .find(parquetType.getFieldName(_) == colName)
-        if (idx.isEmpty) {
-          throw QueryCompilationErrors.invalidVariantMissingFieldError(colName)
-        }
-        val child = parquetType.getType(idx.get)
+          .find(parquetType.getFieldName(_) == colName)
+          .getOrElse(throw QueryCompilationErrors.invalidVariantMissingFieldError(colName))
+        val child = parquetType.getType(idx)
         if (!child.isPrimitive || child.getRepetition != Type.Repetition.REQUIRED ||
-          child.asPrimitiveType().getPrimitiveTypeName != BINARY) {
+            child.asPrimitiveType().getPrimitiveTypeName != BINARY) {
           throw QueryCompilationErrors.invalidVariantNullableOrNotBinaryFieldError(colName)
         }
-        child
+        idx
       }
-      Array(
-        // Converter for value
-        newConverter(valueAndMetadata(0), BinaryType, new ParentContainerUpdater {
+      val result = new Array[Converter with HasParentContainerUpdater](2)
+      result(value) =
+        newConverter(parquetType.getType(value), BinaryType, new ParentContainerUpdater {
           override def set(value: Any): Unit = currentValue = value
-        }),
-
-        // Converter for metadata
-        newConverter(valueAndMetadata(1), BinaryType, new ParentContainerUpdater {
+        })
+      result(metadata) =
+        newConverter(parquetType.getType(metadata), BinaryType, new ParentContainerUpdater {
           override def set(value: Any): Unit = currentMetadata = value
-        }))
+        })
+      result
     }
 
     override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index d7110c736999..9e6f4447ca79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -58,7 +58,9 @@ class ParquetToSparkSchemaConverter(
     caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get,
     inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
     nanosAsLong: Boolean = SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get,
-    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get) {
+    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get,
+    val ignoreVariantAnnotation: Boolean =
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.defaultValue.get) {
 
   def this(conf: SQLConf) = this(
     assumeBinaryIsString = conf.isParquetBinaryAsString,
@@ -66,7 +68,8 @@ class ParquetToSparkSchemaConverter(
     caseSensitive = conf.caseSensitiveAnalysis,
     inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
     nanosAsLong = conf.legacyParquetNanosAsLong,
-    useFieldId = conf.parquetFieldIdReadEnabled)
+    useFieldId = conf.parquetFieldIdReadEnabled,
+    ignoreVariantAnnotation = conf.parquetIgnoreVariantAnnotation)
 
   def this(conf: Configuration) = this(
     assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
@@ -75,7 +78,9 @@ class ParquetToSparkSchemaConverter(
     inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
     nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean,
     useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key,
-      SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get))
+      SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get),
+    ignoreVariantAnnotation = conf.getBoolean(SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key,
+      SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.defaultValue.get))
 
   /**
    * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
@@ -202,15 +207,17 @@ class ParquetToSparkSchemaConverter(
       case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType)
       case groupColumn: GroupColumnIO if targetType.contains(VariantType) =>
         if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
-          val col = convertGroupField(groupColumn)
+          // We need the underlying file type regardless of the config.
+          val col = convertGroupField(groupColumn, ignoreVariantAnnotation = true)
           col.copy(sparkType = VariantType, variantFileType = Some(col))
         } else {
           convertVariantField(groupColumn)
         }
       case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) =>
-        val col = convertGroupField(groupColumn)
+        val col = convertGroupField(groupColumn, ignoreVariantAnnotation = true)
         col.copy(sparkType = targetType.get, variantFileType = Some(col))
-      case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
+      case groupColumn: GroupColumnIO =>
+        convertGroupField(groupColumn, ignoreVariantAnnotation, targetType)
     }
   }
 
@@ -349,6 +356,7 @@ class ParquetToSparkSchemaConverter(
 
   private def convertGroupField(
       groupColumn: GroupColumnIO,
+      ignoreVariantAnnotation: Boolean,
       sparkReadType: Option[DataType] = None): ParquetColumn = {
     val field = groupColumn.getType.asGroupType()
 
@@ -373,9 +381,21 @@ class ParquetToSparkSchemaConverter(
 
     Option(field.getLogicalTypeAnnotation).fold(
       convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) {
-      // Temporary workaround to read Shredded variant data
-      case v: VariantLogicalTypeAnnotation if v.getSpecVersion == 1 && sparkReadType.isEmpty =>
-        convertInternal(groupColumn, None)
+      case v: VariantLogicalTypeAnnotation if v.getSpecVersion == 1 =>
+        if (ignoreVariantAnnotation) {
+          convertInternal(groupColumn)
+        } else {
+          ParquetSchemaConverter.checkConversionRequirement(
+            sparkReadType.forall(_.isInstanceOf[VariantType]),
+            s"Invalid Spark read type: expected $field to be variant type but found " +
+              s"${if (sparkReadType.isEmpty) { "None" } else { sparkReadType.get.sql }}")
+          if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
+            val col = convertInternal(groupColumn)
+            col.copy(sparkType = VariantType, variantFileType = Some(col))
+          } else {
+            convertVariantField(groupColumn)
+          }
+        }
 
       // A Parquet list is represented as a 3-level structure:
       //
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index 1132f074f29d..ca2defffba91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -646,7 +646,9 @@ case object SparkShreddingUtils {
   def parquetTypeToSparkType(parquetType: ParquetType): DataType = {
     val messageType = ParquetTypes.buildMessage().addField(parquetType).named("foo")
     val column = new ColumnIOFactory().getColumnIO(messageType)
-    new ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType
+    // We need the underlying file type regardless of the ignoreVariantAnnotation config.
+    val converter = new ParquetToSparkSchemaConverter(ignoreVariantAnnotation = true)
+    converter.convertField(column.getChild(0)).sparkType
   }
 
   class SparkShreddedResult(schema: VariantSchema) extends VariantShreddingWriter.ShreddedResult {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
index a9ec5e161f34..77140c1a91ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala
@@ -28,7 +28,8 @@ import org.apache.parquet.hadoop.util.HadoopInputFile
 import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Type}
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
@@ -160,64 +161,126 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
     Seq(false, true).foreach { annotateVariantLogicalType =>
       Seq(false, true).foreach { shredVariant =>
         Seq(false, true).foreach { allowReadingShredded =>
-          withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> shredVariant.toString,
-            SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> shredVariant.toString,
-            SQLConf.VARIANT_ALLOW_READING_SHREDDED.key ->
-              (allowReadingShredded || shredVariant).toString,
-            SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key ->
-              annotateVariantLogicalType.toString) {
-            def validateAnnotation(g: Type): Unit = {
-              if (annotateVariantLogicalType) {
-                assert(g.getLogicalTypeAnnotation == LogicalTypeAnnotation.variantType(1))
-              } else {
-                assert(g.getLogicalTypeAnnotation == null)
+          Seq(false, true).foreach { ignoreVariantAnnotation =>
+            withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> shredVariant.toString,
+              SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> shredVariant.toString,
+              SQLConf.VARIANT_ALLOW_READING_SHREDDED.key ->
+                (allowReadingShredded || shredVariant).toString,
+              SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key ->
+                annotateVariantLogicalType.toString,
+              SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString) {
+              def validateAnnotation(g: Type): Unit = {
+                if (annotateVariantLogicalType) {
+                  assert(g.getLogicalTypeAnnotation == LogicalTypeAnnotation.variantType(1))
+                } else {
+                  assert(g.getLogicalTypeAnnotation == null)
+                }
+              }
+              withTempDir { dir =>
+                // write parquet file
+                val df = spark.sql(
+                  """
+                    | select
+                    | id * 2 i,
+                    | to_variant_object(named_struct('id', id)) v,
+                    | named_struct('i', (id * 2)::string,
+                    | 'nv', to_variant_object(named_struct('id', 30 + id))) ns,
+                    | array(to_variant_object(named_struct('id', 10 + id))) av,
+                    | map('v2', to_variant_object(named_struct('id', 20 + id))) mv
+                    | from range(0,3,1,1)""".stripMargin)
+                df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+                val file = dir.listFiles().find(_.getName.endsWith(".parquet")).get
+                val parquetFilePath = file.getAbsolutePath
+                val inputFile = HadoopInputFile.fromPath(new Path(parquetFilePath),
+                  new Configuration())
+                val reader = ParquetFileReader.open(inputFile)
+                val footer = reader.getFooter
+                val schema = footer.getFileMetaData.getSchema
+                val vGroup = schema.getType(schema.getFieldIndex("v"))
+                validateAnnotation(vGroup)
+                assert(vGroup.asGroupType().getFields.asScala.toSeq
+                  .exists(_.getName == "typed_value") == shredVariant)
+                val nsGroup = schema.getType(schema.getFieldIndex("ns")).asGroupType()
+                val nvGroup = nsGroup.getType(nsGroup.getFieldIndex("nv"))
+                validateAnnotation(nvGroup)
+                val avGroup = schema.getType(schema.getFieldIndex("av")).asGroupType()
+                val avList = avGroup.getType(avGroup.getFieldIndex("list")).asGroupType()
+                val avElement = avList.getType(avList.getFieldIndex("element"))
+                validateAnnotation(avElement)
+                val mvGroup = schema.getType(schema.getFieldIndex("mv")).asGroupType()
+                val mvList = mvGroup.getType(mvGroup.getFieldIndex("key_value")).asGroupType()
+                val mvValue = mvList.getType(mvList.getFieldIndex("value"))
+                validateAnnotation(mvValue)
+                // verify result
+                val result = spark.read.format("parquet")
+                  .schema("v variant, ns struct<i string, nv variant>, av array<variant>, " +
+                    "mv map<string, variant>")
+                  .load(dir.getAbsolutePath)
+                  .selectExpr("v:id::int i1", "ns.nv:id::int i2", "av[0]:id::int i3",
+                    "mv['v2']:id::int i4")
+                checkAnswer(result, Array(Row(0, 30, 10, 20), Row(1, 31, 11, 21),
+                  Row(2, 32, 12, 22)))
+                reader.close()
               }
             }
-            withTempDir { dir =>
-              // write parquet file
-              val df = spark.sql(
-                """
-                  | select
-                  | id * 2 i,
-                  | to_variant_object(named_struct('id', id)) v,
-                  | named_struct('i', (id * 2)::string,
-                  | 'nv', to_variant_object(named_struct('id', 30 + id))) ns,
-                  | array(to_variant_object(named_struct('id', 10 + id))) av,
-                  | map('v2', to_variant_object(named_struct('id', 20 + id))) mv
-                  | from range(0,3,1,1)""".stripMargin)
-              df.write.mode("overwrite").parquet(dir.getAbsolutePath)
-              val file = dir.listFiles().find(_.getName.endsWith(".parquet")).get
-              val parquetFilePath = file.getAbsolutePath
-              val inputFile = HadoopInputFile.fromPath(new Path(parquetFilePath),
-                new Configuration())
-              val reader = ParquetFileReader.open(inputFile)
-              val footer = reader.getFooter
-              val schema = footer.getFileMetaData.getSchema
-              val vGroup = schema.getType(schema.getFieldIndex("v"))
-              validateAnnotation(vGroup)
-              assert(vGroup.asGroupType().getFields.asScala.toSeq
-                .exists(_.getName == "typed_value") == shredVariant)
-              val nsGroup = schema.getType(schema.getFieldIndex("ns")).asGroupType()
-              val nvGroup = nsGroup.getType(nsGroup.getFieldIndex("nv"))
-              validateAnnotation(nvGroup)
-              val avGroup = schema.getType(schema.getFieldIndex("av")).asGroupType()
-              val avList = avGroup.getType(avGroup.getFieldIndex("list")).asGroupType()
-              val avElement = avList.getType(avList.getFieldIndex("element"))
-              validateAnnotation(avElement)
-              val mvGroup = schema.getType(schema.getFieldIndex("mv")).asGroupType()
-              val mvList = mvGroup.getType(mvGroup.getFieldIndex("key_value")).asGroupType()
-              val mvValue = mvList.getType(mvList.getFieldIndex("value"))
-              validateAnnotation(mvValue)
-              // verify result
-              val result = spark.read.format("parquet")
-                .schema("v variant, ns struct<i string, nv variant>, av array<variant>, " +
-                  "mv map<string, variant>")
-                .load(dir.getAbsolutePath)
-                .selectExpr("v:id::int i1", "ns.nv:id::int i2", "av[0]:id::int i3",
-                  "mv['v2']:id::int i4")
-              checkAnswer(result, Array(Row(0, 30, 10, 20), Row(1, 31, 11, 21), Row(2, 32, 12, 22)))
-              reader.close()
+          }
+        }
+      }
+    }
+  }
+
+  test("variant logical type annotation - ignore variant annotation") {
+    Seq(true, false).foreach { ignoreVariantAnnotation =>
+      withSQLConf(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key -> "true",
+        SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key -> ignoreVariantAnnotation.toString
+      ) {
+        withTempDir { dir =>
+          // write parquet file
+          val df = spark.sql(
+            """
+              | select
+              | id * 2 i,
+              | 1::variant v,
+              | named_struct('i', (id * 2)::string, 'nv', 1::variant) ns,
+              | array(1::variant) av,
+              | map('v2', 1::variant) mv
+              | from range(0,1,1,1)""".stripMargin)
+          df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+          // verify result
+          val normal_result = spark.read.format("parquet")
+            .schema("v variant, ns struct<i string, nv variant>, av array<variant>, " +
+              "mv map<string, variant>")
+            .load(dir.getAbsolutePath)
+            .selectExpr("v::int i1", "ns.nv::int i2", "av[0]::int i3",
+              "mv['v2']::int i4")
+          checkAnswer(normal_result, Array(Row(1, 1, 1, 1)))
+          val struct_result = spark.read.format("parquet")
+            .schema("v struct<value binary, metadata binary>, " +
+              "ns struct<i string, nv struct<value binary, metadata binary>>, " +
+              "av array<struct<value binary, metadata binary>>, " +
+              "mv map<string, struct<value binary, metadata binary>>")
+            .load(dir.getAbsolutePath)
+            .selectExpr("v", "ns.nv", "av[0]", "mv['v2']")
+          if (ignoreVariantAnnotation) {
+            checkAnswer(
+              struct_result,
+              Seq(Row(
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0)),
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0)),
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0)),
+                Row(Array[Byte](12, 1), Array[Byte](1, 0, 0))
+              ))
+            )
+          } else {
+            val exception = intercept[SparkException] {
+              struct_result.collect()
+            }
+            checkError(
+              exception = exception.getCause.asInstanceOf[AnalysisException],
+              condition = "_LEGACY_ERROR_TEMP_3071",
+              parameters = Map("msg" -> "Invalid Spark read type[\\s\\S]*"),
+              matchPVals = true
+            )
+          }
+        }
+      }