diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala index b17e89b536103..00497c1c31f35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala @@ -544,6 +544,11 @@ class StaxXmlParser( case ShortType => castTo(value, ShortType) case IntegerType => signSafeToInt(value) case dt: DecimalType => castTo(value, dt) + case VariantType => + val builder = new VariantBuilder(false) + StaxXmlParser.appendXMLCharacterToVariant(builder, value, options) + val v = builder.result() + new VariantVal(v.getValue, v.getMetadata) case _ => throw new SparkIllegalArgumentException( errorClass = "_LEGACY_ERROR_TEMP_3246", messageParameters = Map("dataType" -> dataType.toString)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala index 505aa15b34872..63b816ad6b53a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlVariantSuite.scala @@ -418,7 +418,9 @@ class XmlVariantSuite extends QueryTest with SharedSparkSession with TestXmlData val df = createDSLDataFrame( fileName = "books-complicated.xml", schemaDDL = Some( - "_id string, author string, title string, " + + "_id variant, " + // Attribute as variant + "author string, " + + "title string, " + "genre struct, " + // Struct with variant "price variant, " + // Scalar as variant "publish_dates struct>" // Array with variant @@ -427,11 +429,16 @@ class XmlVariantSuite extends QueryTest with SharedSparkSession with TestXmlData ) checkAnswer( df.select( + variant_get(col("_id"), "$", "string"), variant_get(col("genre.name"), "$", "string"), variant_get(col("price"), "$", "double"), variant_get(col("publish_dates.publish_date").getItem(0), "$.month", "int") ), - Seq(Row("Computer", 44.95, 10), Row("Fantasy", 5.95, 12), Row("Fantasy", null, 11)) + Seq( + Row("bk101", "Computer", 44.95, 10), + Row("bk102", "Fantasy", 5.95, 12), + Row("bk103", "Fantasy", null, 11) + ) ) }