Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-45399][SQL] XML: Add XML Options using newOption #43201

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -35,7 +35,6 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, BadRecordException
import org.apache.spark.sql.catalyst.xml.StaxXmlParser.convertStream
import org.apache.spark.sql.catalyst.xml.TypeCast._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -47,24 +46,6 @@ class StaxXmlParser(

private val factory = options.buildXmlFactory()

// Flags to signal if we need to fall back to the backward compatible behavior of parsing
// dates and timestamps.
// For more information, see comments for "enableDateTimeParsingFallback" option in XmlOptions.
private val enableParsingFallbackForTimestampType =
options.enableDateTimeParsingFallback
.orElse(SQLConf.get.jsonEnableDateTimeParsingFallback)
.getOrElse {
SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY ||
options.timestampFormatInRead.isEmpty
}
private val enableParsingFallbackForDateType =
options.enableDateTimeParsingFallback
.orElse(SQLConf.get.jsonEnableDateTimeParsingFallback)
.getOrElse {
SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY ||
options.dateFormatInRead.isEmpty
}

/**
* Parses a single XML string and turns it into either one resulting row or no row (if the
* record is malformed).
Expand Down
Expand Up @@ -22,10 +22,7 @@ import java.util.Locale

import scala.util.Try
import scala.util.control.Exception._
import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -85,25 +82,7 @@ private[sql] object TypeCast {
}

private def parseXmlTimestamp(value: String, options: XmlOptions): Long = {
try {
options.timestampFormatter.parse(value)
} catch {
case NonFatal(e) =>
// If fails to parse, then tries the way used in 2.0 and 1.x for backwards
// compatibility if enabled.
val enableParsingFallbackForTimestampType =
options.enableDateTimeParsingFallback
.orElse(SQLConf.get.jsonEnableDateTimeParsingFallback)
.getOrElse {
SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY ||
options.timestampFormatInRead.isEmpty
}
if (!enableParsingFallbackForTimestampType) {
throw e
}
val str = DateTimeUtils.cleanLegacyTimestampStr(UTF8String.fromString(value))
DateTimeUtils.stringToTimestamp(str, options.zoneId).getOrElse(throw e)
}
options.timestampFormatter.parse(value)
}

// TODO: This function unnecessarily does type dispatch. Should merge it with `castTo`.
Expand Down
Expand Up @@ -42,7 +42,7 @@ private[sql] class XmlOptions(
def this(
parameters: Map[String, String] = Map.empty,
defaultTimeZoneId: String = SQLConf.get.sessionLocalTimeZone,
defaultColumnNameOfCorruptRecord: String = "") = {
defaultColumnNameOfCorruptRecord: String = SQLConf.get.columnNameOfCorruptRecord) = {
this(
CaseInsensitiveMap(parameters),
defaultTimeZoneId,
Expand All @@ -62,42 +62,39 @@ private[sql] class XmlOptions(
}
}

val compressionCodec = parameters.get("compression").orElse(parameters.get("codec"))
.map(CompressionCodecs.getCodecClassName)
val rowTag = parameters.getOrElse("rowTag", XmlOptions.DEFAULT_ROW_TAG)
require(rowTag.nonEmpty, "'rowTag' option should not be empty string.")
val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG)
require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be empty string.")
require(!rowTag.startsWith("<") && !rowTag.endsWith(">"),
"'rowTag' should not include angle brackets")
val rootTag = parameters.getOrElse("rootTag", XmlOptions.DEFAULT_ROOT_TAG)
s"'$ROW_TAG' should not include angle brackets")
val rootTag = parameters.getOrElse(ROOT_TAG, XmlOptions.DEFAULT_ROOT_TAG)
require(!rootTag.startsWith("<") && !rootTag.endsWith(">"),
"'rootTag' should not include angle brackets")
val declaration = parameters.getOrElse("declaration", XmlOptions.DEFAULT_DECLARATION)
s"'$ROOT_TAG' should not include angle brackets")
val declaration = parameters.getOrElse(DECLARATION, XmlOptions.DEFAULT_DECLARATION)
require(!declaration.startsWith("<") && !declaration.endsWith(">"),
"'declaration' should not include angle brackets")
val arrayElementName = parameters.getOrElse("arrayElementName",
s"'$DECLARATION' should not include angle brackets")
val arrayElementName = parameters.getOrElse(ARRAY_ELEMENT_NAME,
XmlOptions.DEFAULT_ARRAY_ELEMENT_NAME)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val excludeAttributeFlag = parameters.get("excludeAttribute").map(_.toBoolean).getOrElse(false)
val treatEmptyValuesAsNulls =
parameters.get("treatEmptyValuesAsNulls").map(_.toBoolean).getOrElse(false)
val samplingRatio = parameters.get(SAMPLING_RATIO).map(_.toDouble).getOrElse(1.0)
require(samplingRatio > 0, s"$SAMPLING_RATIO ($samplingRatio) should be greater than 0")
val excludeAttributeFlag = getBool(EXCLUDE_ATTRIBUTE, false)
val treatEmptyValuesAsNulls = getBool(TREAT_EMPTY_VALUE_AS_NULLS, false)
val attributePrefix =
parameters.getOrElse("attributePrefix", XmlOptions.DEFAULT_ATTRIBUTE_PREFIX)
val valueTag = parameters.getOrElse("valueTag", XmlOptions.DEFAULT_VALUE_TAG)
require(valueTag.nonEmpty, "'valueTag' option should not be empty string.")
parameters.getOrElse(ATTRIBUTE_PREFIX, XmlOptions.DEFAULT_ATTRIBUTE_PREFIX)
val valueTag = parameters.getOrElse(VALUE_TAG, XmlOptions.DEFAULT_VALUE_TAG)
require(valueTag.nonEmpty, s"'$VALUE_TAG' option should not be empty string.")
require(valueTag != attributePrefix,
"'valueTag' and 'attributePrefix' options should not be the same.")
val nullValue = parameters.getOrElse("nullValue", XmlOptions.DEFAULT_NULL_VALUE)
s"'$VALUE_TAG' and '$ATTRIBUTE_PREFIX' options should not be the same.")
val nullValue = parameters.getOrElse(NULL_VALUE, XmlOptions.DEFAULT_NULL_VALUE)
val columnNameOfCorruptRecord =
parameters.getOrElse("columnNameOfCorruptRecord", "_corrupt_record")
val ignoreSurroundingSpaces =
parameters.get("ignoreSurroundingSpaces").map(_.toBoolean).getOrElse(false)
val parseMode = ParseMode.fromString(parameters.getOrElse("mode", PermissiveMode.name))
val inferSchema = parameters.get("inferSchema").map(_.toBoolean).getOrElse(true)
val rowValidationXSDPath = parameters.get("rowValidationXSDPath").orNull
parameters.getOrElse(COLUMN_NAME_OF_CORRUPT_RECORD, defaultColumnNameOfCorruptRecord)
val ignoreSurroundingSpaces = getBool(IGNORE_SURROUNDING_SPACES, false)
val parseMode = ParseMode.fromString(parameters.getOrElse(MODE, PermissiveMode.name))
val inferSchema = getBool(INFER_SCHEMA, true)
val rowValidationXSDPath = parameters.get(ROW_VALIDATION_XSD_PATH).orNull
val wildcardColName =
parameters.getOrElse("wildcardColName", XmlOptions.DEFAULT_WILDCARD_COL_NAME)
val ignoreNamespace = parameters.get("ignoreNamespace").map(_.toBoolean).getOrElse(false)
parameters.getOrElse(WILDCARD_COL_NAME, XmlOptions.DEFAULT_WILDCARD_COL_NAME)
val ignoreNamespace = getBool(IGNORE_NAMESPACE, false)

/**
* Infer columns with all valid date entries as date type (otherwise inferred as string or
Expand Down Expand Up @@ -142,17 +139,6 @@ private[sql] class XmlOptions(
s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]"
})

// SPARK-39731: Enables the backward compatible parsing behavior.
// Generally, this config should be set to false to avoid producing potentially incorrect results
// which is the current default (see JacksonParser).
//
// If enabled and the date cannot be parsed, we will fall back to `DateTimeUtils.stringToDate`.
// If enabled and the timestamp cannot be parsed, `DateTimeUtils.stringToTimestamp` will be used.
// Otherwise, depending on the parser policy and a custom pattern, an exception may be thrown and
// the value will be parsed as null.
val enableDateTimeParsingFallback: Option[Boolean] =
parameters.get(ENABLE_DATETIME_PARSING_FALLBACK).map(_.toBoolean)

val timezone = parameters.get("timezone")

val zoneId: ZoneId = DateTimeUtils.getZoneId(
Expand Down Expand Up @@ -207,19 +193,34 @@ private[sql] object XmlOptions extends DataSourceOptions {
val DEFAULT_CHARSET: String = StandardCharsets.UTF_8.name
val DEFAULT_NULL_VALUE: String = null
val DEFAULT_WILDCARD_COL_NAME = "xs_any"
val ROW_TAG = newOption("rowTag")
val ROOT_TAG = newOption("rootTag")
val DECLARATION = newOption("declaration")
val ARRAY_ELEMENT_NAME = newOption("arrayElementName")
val EXCLUDE_ATTRIBUTE = newOption("excludeAttribute")
val TREAT_EMPTY_VALUE_AS_NULLS = newOption("treatEmptyValuesAsNulls")
val ATTRIBUTE_PREFIX = newOption("attributePrefix")
val VALUE_TAG = newOption("valueTag")
val NULL_VALUE = newOption("nullValue")
val IGNORE_SURROUNDING_SPACES = newOption("ignoreSurroundingSpaces")
val ROW_VALIDATION_XSD_PATH = newOption("rowValidationXSDPath")
val WILDCARD_COL_NAME = newOption("wildcardColName")
val IGNORE_NAMESPACE = newOption("ignoreNamespace")
val INFER_SCHEMA = newOption("inferSchema")
val PREFER_DATE = newOption("preferDate")
val MODE = newOption("mode")
val LOCALE = newOption("locale")
val COMPRESSION = newOption("compression")
val ENABLE_DATETIME_PARSING_FALLBACK = newOption("enableDateTimeParsingFallback")
val MULTI_LINE = newOption("multiLine")
val SAMPLING_RATIO = newOption("samplingRatio")
val COLUMN_NAME_OF_CORRUPT_RECORD = newOption("columnNameOfCorruptRecord")
val DATE_FORMAT = newOption("dateFormat")
val TIMESTAMP_FORMAT = newOption("timestampFormat")
val TIME_ZONE = newOption("timeZone")
// Options with alternative
val ENCODING = "encoding"
val CHARSET = "charset"
newOption(ENCODING, CHARSET)
val TIME_ZONE = "timezone"
newOption(DateTimeUtils.TIMEZONE_OPTION, TIME_ZONE)

def apply(parameters: Map[String, String]): XmlOptions =
new XmlOptions(parameters, SQLConf.get.sessionLocalTimeZone)
Expand Down
Expand Up @@ -332,7 +332,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
cars.write
.mode(SaveMode.Overwrite)
.options(Map("codec" -> classOf[GzipCodec].getName))
.options(Map("compression" -> classOf[GzipCodec].getName))
.xml(copyFilePath.toString)
// Check that the part file has a .gz extension
assert(Files.list(copyFilePath).iterator().asScala
Expand Down