[SPARK-46152][SQL] XML: Add DecimalType support in XML schema inference
### What changes were proposed in this pull request?
Add DecimalType support in XML schema inference, gated behind a new `prefersDecimal` read option (default `false`).

### Why are the changes needed?
Without DecimalType support, XML schema inference widens integral values that overflow LongType to DoubleType, which loses precision for values such as `92233720368547758070`. Inferring DecimalType preserves such values exactly.

### Does this PR introduce _any_ user-facing change?
Yes. With `prefersDecimal` set to `true`, XML schema inference infers DecimalType for scale-0 numeric values that do not fit LongType, instead of falling back to DoubleType.
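A minimal usage sketch of the new behavior (the file path and XML payload are illustrative, not part of this patch):

```scala
// Hypothetical input file containing <ROW><amount>92233720368547758070</amount></ROW>:
// the 20-digit value overflows LongType and would previously infer as DoubleType.
val df = spark.read
  .option("rowTag", "ROW")
  .option("prefersDecimal", "true")
  .xml("/tmp/amounts.xml")
df.printSchema() // amount: decimal(20,0) instead of double
```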

### How was this patch tested?
Added a new unit test

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #44069 from sandip-db/xml-decimalType.

Authored-by: Sandip Agarwala <131817656+sandip-db@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
sandip-db authored and HyukjinKwon committed Nov 30, 2023
1 parent 350af64 commit cfcc250
Showing 3 changed files with 106 additions and 13 deletions.
XmlInferSchema.scala
@@ -31,6 +31,7 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.{DateFormatter, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
 import org.apache.spark.sql.types._
@@ -39,6 +40,8 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
   extends Serializable
   with Logging {
 
+  private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+
   private val timestampFormatter = TimestampFormatter(
     options.timestampFormatInRead,
     options.zoneId,
@@ -132,11 +135,12 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     }
 
     if (options.inferSchema) {
+      lazy val decimalTry = tryParseDecimal(value)
       value match {
         case null => NullType
         case v if v.isEmpty => NullType
         case v if isLong(v) => LongType
-        case v if isInteger(v) => IntegerType
+        case v if options.prefersDecimal && decimalTry.isDefined => decimalTry.get
         case v if isDouble(v) => DoubleType
         case v if isBoolean(v) => BooleanType
         case v if isDate(v) => DateType
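With the `isInteger` branch removed, the effective numeric inference order becomes Long, then Decimal (only when `prefersDecimal` is set and the value has no fractional part), then Double. A standalone sketch of that ordering (plain Scala, with simplified guards standing in for this file's `isLong`, `tryParseDecimal`, and `isDouble`):

```scala
import scala.util.Try

object InferenceOrderSketch extends App {
  def infer(v: String, prefersDecimal: Boolean): String = {
    // Simplified stand-ins for the guards in the match above.
    lazy val decimalTry = Try(BigDecimal(v)).toOption.filter(_.scale <= 0)
    if (Try(v.toLong).isSuccess) "LongType"
    else if (prefersDecimal && decimalTry.isDefined)
      s"DecimalType(${decimalTry.get.precision}, ${decimalTry.get.scale})"
    else if (Try(v.toDouble).isSuccess) "DoubleType"
    else "StringType"
  }

  println(infer("10", prefersDecimal = true))                    // LongType
  println(infer("92233720368547758070", prefersDecimal = true))  // DecimalType(20, 0)
  println(infer("1.1", prefersDecimal = true))                   // DoubleType (scale > 0)
  println(infer("92233720368547758070", prefersDecimal = false)) // DoubleType
}
```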
@@ -305,40 +309,63 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     }
   }
 
-  private def isDouble(value: String): Boolean = {
+  private def tryParseDecimal(value: String): Option[DataType] = {
     val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
       value.substring(1)
     } else {
       value
     }
     // A little shortcut to avoid trying many formatters in the common case that
-    // the input isn't a double. All built-in formats will start with a digit or period.
+    // the input isn't a decimal. All built-in formats will start with a digit or period.
     if (signSafeValue.isEmpty ||
       !(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) {
-      return false
+      return None
     }
     // Rule out strings ending in D or F, as they will parse as double but should be disallowed
-    if (value.nonEmpty && (value.last match {
+    if (signSafeValue.last match {
       case 'd' | 'D' | 'f' | 'F' => true
       case _ => false
-    })) {
-      return false
+    }) {
+      return None
     }
-    (allCatch opt signSafeValue.toDouble).isDefined
+
+    try {
+      // The conversion can fail when the `field` is not a form of number.
+      val bigDecimal = decimalParser(signSafeValue)
+      // Because many other formats do not support decimal, it reduces the cases for
+      // decimals by disallowing values having scale (e.g. `1.1`).
+      if (bigDecimal.scale <= 0) {
+        // `DecimalType` conversion can fail when
+        //  1. The precision is bigger than 38.
+        //  2. scale is bigger than precision.
+        return Some(DecimalType(bigDecimal.precision, bigDecimal.scale))
+      }
+    } catch {
+      case _: Exception =>
+    }
+    None
   }
 
-  private def isInteger(value: String): Boolean = {
+  private def isDouble(value: String): Boolean = {
     val signSafeValue = if (value.startsWith("+") || value.startsWith("-")) {
       value.substring(1)
     } else {
       value
     }
     // A little shortcut to avoid trying many formatters in the common case that
-    // the input isn't a number. All built-in formats will start with a digit.
-    if (signSafeValue.isEmpty || !Character.isDigit(signSafeValue.head)) {
+    // the input isn't a double. All built-in formats will start with a digit or period.
+    if (signSafeValue.isEmpty ||
+      !(Character.isDigit(signSafeValue.head) || signSafeValue.head == '.')) {
       return false
     }
-    (allCatch opt signSafeValue.toInt).isDefined
+    // Rule out strings ending in D or F, as they will parse as double but should be disallowed
+    if (signSafeValue.last match {
+      case 'd' | 'D' | 'f' | 'F' => true
+      case _ => false
+    }) {
+      return false
+    }
+    (allCatch opt signSafeValue.toDouble).isDefined
  }
 
  private def isLong(value: String): Boolean = {
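For reference, a standalone sketch (plain Scala, no Spark) of the precision/scale values `tryParseDecimal` inspects, and of why the `DecimalType(...)` call above is wrapped in try/catch:

```scala
object DecimalBoundsSketch extends App {
  val wide = BigDecimal("92233720368547758070")
  println((wide.precision, wide.scale)) // (20,0): scale <= 0, so DecimalType(20, 0)

  val fractional = BigDecimal("1.1")
  println((fractional.precision, fractional.scale)) // (2,1): scale > 0, skipped

  // Spark's DecimalType allows at most 38 digits of precision, so a 39-digit
  // value parses fine as BigDecimal but makes DecimalType(precision, scale)
  // throw; the try/catch above turns that failure into None.
  println(BigDecimal("9" * 39).precision) // 39
}
```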
XmlOptions.scala
@@ -102,6 +102,8 @@ class XmlOptions(
   val wildcardColName =
     parameters.getOrElse(WILDCARD_COL_NAME, XmlOptions.DEFAULT_WILDCARD_COL_NAME)
   val ignoreNamespace = getBool(IGNORE_NAMESPACE, false)
+  val prefersDecimal =
+    parameters.get(PREFERS_DECIMAL).map(_.toBoolean).getOrElse(false)
   // setting indent to "" disables indentation in the generated XML.
   // Each row will be written in a new line.
   val indent = parameters.getOrElse(INDENT, DEFAULT_INDENT)
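The option defaults to `false` when absent; a minimal sketch of the lookup pattern above, assuming a plain `Map` in place of Spark's case-insensitive parameter map:

```scala
val parameters = Map("prefersDecimal" -> "true")
val prefersDecimal = parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false)
// prefersDecimal == true; an absent key yields false, and a value that is not
// a Boolean literal makes _.toBoolean throw IllegalArgumentException.
```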
@@ -207,6 +209,7 @@ object XmlOptions extends DataSourceOptions
   val TIMESTAMP_NTZ_FORMAT = newOption("timestampNTZFormat")
   val TIME_ZONE = newOption("timeZone")
   val INDENT = newOption("indent")
+  val PREFERS_DECIMAL = newOption("prefersDecimal")
   // Options with alternative
   val ENCODING = "encoding"
   val CHARSET = "charset"
XmlSuite.scala
@@ -32,7 +32,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.io.compress.GzipCodec
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, QueryTest, Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoders, QueryTest, Row, SaveMode}
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.xml.XmlOptions
 import org.apache.spark.sql.catalyst.xml.XmlOptions._
@@ -2115,4 +2115,67 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  def testWriteReadRoundTrip(df: DataFrame,
+      options: Map[String, String] = Map.empty): Unit = {
+    withTempDir { dir =>
+      df.write
+        .options(options)
+        .option("rowTag", "ROW")
+        .mode("overwrite")
+        .xml(dir.getCanonicalPath)
+      val df2 = spark.read
+        .options(options)
+        .option("rowTag", "ROW")
+        .xml(dir.getCanonicalPath)
+      checkAnswer(df, df2)
+    }
+  }
+
+  def primitiveFieldAndType: Dataset[String] =
+    spark.createDataset(spark.sparkContext.parallelize("""
+      <ROW>
+        <string>this is a simple string.</string>
+        <integer>10</integer>
+        <long>21474836470</long>
+        <decimal>92233720368547758070</decimal>
+        <double>1.7976931348623157</double>
+        <boolean>true</boolean>
+        <null>null</null>
+      </ROW>""" :: Nil))(Encoders.STRING)
+
+  test("Primitive field and type inferring") {
+    val dfWithNodecimal = spark.read
+      .option("nullValue", "null")
+      .xml(primitiveFieldAndType)
+    assert(dfWithNodecimal.schema("decimal").dataType === DoubleType)
+
+    val df = spark.read
+      .option("nullValue", "null")
+      .option("prefersDecimal", "true")
+      .xml(primitiveFieldAndType)
+
+    val expectedSchema = StructType(
+      StructField("boolean", BooleanType, true) ::
+      StructField("decimal", DecimalType(20, 0), true) ::
+      StructField("double", DoubleType, true) ::
+      StructField("integer", LongType, true) ::
+      StructField("long", LongType, true) ::
+      StructField("null", StringType, true) ::
+      StructField("string", StringType, true) :: Nil)
+
+    assert(df.schema === expectedSchema)
+
+    checkAnswer(
+      df,
+      Row(true,
+        new java.math.BigDecimal("92233720368547758070"),
+        1.7976931348623157,
+        10,
+        21474836470L,
+        null,
+        "this is a simple string.")
+    )
+    testWriteReadRoundTrip(df, Map("nullValue" -> "null", "prefersDecimal" -> "true"))
+  }
 }
