[SPARK-24774][SQL] Avro: Support logical decimal type #22037
SchemaConverters.scala:

```diff
@@ -18,19 +18,28 @@
 package org.apache.spark.sql.avro

 import scala.collection.JavaConverters._
 import scala.util.Random

 import com.fasterxml.jackson.annotation.ObjectIdGenerators.UUIDGenerator
 import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder}
-import org.apache.avro.LogicalTypes.{Date, TimestampMicros, TimestampMillis}
+import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
 import org.apache.avro.Schema.Type._

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
 import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision}

 /**
  * This object contains method that are used to convert sparkSQL schemas to avro schemas and vice
  * versa.
  */
 object SchemaConverters {
   private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong())
```
Member: Seems this is an unused value.

Author: Yes, I made a PR to clean it up: #34472
```diff
+  private lazy val nullSchema = Schema.create(Schema.Type.NULL)

   case class SchemaType(dataType: DataType, nullable: Boolean)

   /**
```
|
```diff
@@ -44,14 +53,20 @@ object SchemaConverters {
       }
       case STRING => SchemaType(StringType, nullable = false)
       case BOOLEAN => SchemaType(BooleanType, nullable = false)
-      case BYTES => SchemaType(BinaryType, nullable = false)
+      case BYTES | FIXED => avroSchema.getLogicalType match {
+        // For FIXED type, if the precision requires more bytes than fixed size, the logical
+        // type will be null, which is handled by Avro library.
+        case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false)
+        case _ => SchemaType(BinaryType, nullable = false)
+      }
+
       case DOUBLE => SchemaType(DoubleType, nullable = false)
       case FLOAT => SchemaType(FloatType, nullable = false)
       case LONG => avroSchema.getLogicalType match {
         case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
         case _ => SchemaType(LongType, nullable = false)
       }
-      case FIXED => SchemaType(BinaryType, nullable = false)

       case ENUM => SchemaType(StringType, nullable = false)

       case RECORD =>
```
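To see the new mapping end to end, here is a minimal sketch of reading a decimal schema through the converter (assuming the public toSqlType entry point of SchemaConverters, which wraps the match above; it is not part of this diff):

```scala
import org.apache.avro.Schema
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.types.DecimalType

// A bytes schema carrying the decimal logical type...
val avroSchema = new Schema.Parser().parse(
  """{"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}""")

// ...now converts to DecimalType(4, 2) instead of BinaryType.
val sqlType = SchemaConverters.toSqlType(avroSchema)
assert(sqlType.dataType == DecimalType(4, 2))
```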
|
|
```diff
@@ -114,32 +129,36 @@ object SchemaConverters {
       prevNameSpace: String = "",
       outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS)
     : Schema = {
-    val builder = if (nullable) {
-      SchemaBuilder.builder().nullable()
-    } else {
-      SchemaBuilder.builder()
-    }
+    val builder = SchemaBuilder.builder()

-    catalystType match {
+    val schema = catalystType match {
       case BooleanType => builder.booleanType()
       case ByteType | ShortType | IntegerType => builder.intType()
       case LongType => builder.longType()
-      case DateType => builder
-        .intBuilder()
-        .prop(LogicalType.LOGICAL_TYPE_PROP, LogicalTypes.date().getName)
-        .endInt()
+      case DateType =>
+        LogicalTypes.date().addToSchema(builder.intType())
       case TimestampType =>
         val timestampType = outputTimestampType match {
           case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis()
           case AvroOutputTimestampType.TIMESTAMP_MICROS => LogicalTypes.timestampMicros()
           case other =>
             throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.")
         }
-        builder.longBuilder().prop(LogicalType.LOGICAL_TYPE_PROP, timestampType.getName).endLong()
+        timestampType.addToSchema(builder.longType())

       case FloatType => builder.floatType()
       case DoubleType => builder.doubleType()
-      case _: DecimalType | StringType => builder.stringType()
+      case StringType => builder.stringType()
+      case d: DecimalType =>
+        val avroType = LogicalTypes.decimal(d.precision, d.scale)
+        val fixedSize = minBytesForPrecision(d.precision)
+        // Need to avoid naming conflict for the fixed fields
+        val name = prevNameSpace match {
+          case "" => s"$recordName.fixed"
+          case _ => s"$prevNameSpace.$recordName.fixed"
+        }
+        avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize))
+
       case BinaryType => builder.bytesType()
       case ArrayType(et, containsNull) =>
         builder.array()
```
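The fixed size above comes from Decimal.minBytesForPrecision (imported at the top of the file): the fewest bytes whose signed two's-complement range covers every unscaled value of the given precision. A sketch of the arithmetic behind it (Spark precomputes these values in a lookup table):

```scala
// Smallest n such that 2^(8n - 1) >= 10^precision, i.e. an n-byte signed
// integer can hold any unscaled value with `precision` decimal digits.
def minBytesForPrecision(precision: Int): Int = {
  var numBytes = 1
  while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
    numBytes += 1
  }
  numBytes
}

minBytesForPrecision(4)   // 2  -> DecimalType(4, 2) becomes a 2-byte fixed
minBytesForPrecision(38)  // 16 -> Spark's maximum precision fits in 16 bytes
```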
|
|
```diff
@@ -164,6 +183,11 @@ object SchemaConverters {
       // This should never happen.
       case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
     }
+    if (nullable) {
+      Schema.createUnion(schema, nullSchema)
+    } else {
+      schema
+    }
   }
 }
```
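Nullability is now applied once, at the end: the match always builds the non-null schema, and nullable fields are wrapped in a union with the null schema (value branch first). Unlike SchemaBuilder's nullable() shortcut, this leaves the logical-type schemas built above untouched. A sketch of the observable result, assuming the method shown is spark-avro's toAvroType:

```scala
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.types.DecimalType

val avro = SchemaConverters.toAvroType(DecimalType(4, 2), nullable = true, "rec")
// A two-branch union, value schema first, roughly:
//   [ {"type": "fixed", "name": "fixed", "namespace": "rec", "size": 2,
//      "logicalType": "decimal", "precision": 4, "scale": 2},
//     "null" ]
```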
AvroSuite.scala:

```diff
@@ -25,7 +25,8 @@ import java.util.{TimeZone, UUID}

 import scala.collection.JavaConverters._

-import org.apache.avro.Schema
+import org.apache.avro.{LogicalTypes, Schema}
+import org.apache.avro.Conversions.DecimalConversion
 import org.apache.avro.Schema.{Field, Type}
 import org.apache.avro.file.{DataFileReader, DataFileWriter}
 import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
```
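The tests below use Avro's Conversions.DecimalConversion to encode values. For reference, a self-contained round trip through the bytes encoding (only the big-endian two's-complement unscaled value is written; the scale lives in the logical type):

```scala
import org.apache.avro.{LogicalTypes, SchemaBuilder}
import org.apache.avro.Conversions.DecimalConversion

val logical = LogicalTypes.decimal(4, 2)
val schema = logical.addToSchema(SchemaBuilder.builder().bytesType())
val conversion = new DecimalConversion

// toBytes writes the unscaled value; fromBytes reapplies the scale.
val buf = conversion.toBytes(new java.math.BigDecimal("3.14"), schema, logical)
val back = conversion.fromBytes(buf, schema, logical) // == 3.14
```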
|
|
```diff
@@ -494,6 +495,104 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
     checkAnswer(df, expected)
   }

+  test("Logical type: Decimal") {
+    val precision = 4
+    val scale = 2
+    val bytesFieldName = "bytes"
+    val bytesSchema = s"""{
+         "type":"bytes",
+         "logicalType":"decimal",
+         "precision":$precision,
+         "scale":$scale
+      }
+    """
+
+    val fixedFieldName = "fixed"
+    val fixedSchema = s"""{
+         "type":"fixed",
+         "size":5,
+         "logicalType":"decimal",
+         "precision":$precision,
+         "scale":$scale,
+         "name":"foo"
+      }
+    """
+    val avroSchema = s"""
+      {
+        "namespace": "logical",
+        "type": "record",
+        "name": "test",
+        "fields": [
+          {"name": "$bytesFieldName", "type": $bytesSchema},
+          {"name": "$fixedFieldName", "type": $fixedSchema}
+        ]
+      }
+    """
+    val schema = new Schema.Parser().parse(avroSchema)
+    val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+    val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+    val decimalConversion = new DecimalConversion
+    withTempDir { dir =>
+      val avroFile = s"$dir.avro"
+      dataFileWriter.create(schema, new File(avroFile))
+      val logicalType = LogicalTypes.decimal(precision, scale)
+      val data = Seq("1.23", "4.56", "78.90", "-1", "-2.31")
+      data.map { x =>
+        val avroRec = new GenericData.Record(schema)
+        val decimal = new java.math.BigDecimal(x).setScale(scale)
+        val bytes =
+          decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType)
+        avroRec.put(bytesFieldName, bytes)
+        val fixed =
+          decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType)
+        avroRec.put(fixedFieldName, fixed)
+        dataFileWriter.append(avroRec)
+      }
+      dataFileWriter.flush()
+      dataFileWriter.close()
+
+      val expected = data.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) }
+      val df = spark.read.format("avro").load(avroFile)
+      checkAnswer(df, expected)
+      checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile),
+        expected)
+
+      withTempPath { path =>
+        df.write.format("avro").save(path.toString)
+        checkAnswer(spark.read.format("avro").load(path.toString), expected)
+      }
+    }
+  }
+
+  test("Logical type: Decimal with too large precision") {
+    withTempDir { dir =>
+      val schema = new Schema.Parser().parse("""{
+        "namespace": "logical",
+        "type": "record",
+        "name": "test",
+        "fields": [{
+          "name": "decimal",
+          "type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}
+        }]
+      }""")
+      val datumWriter = new GenericDatumWriter[GenericRecord](schema)
+      val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
+      dataFileWriter.create(schema, new File(s"$dir.avro"))
```
Contributor: Let's either always use Python to write test files, or always use Java.
```diff
+      val avroRec = new GenericData.Record(schema)
+      val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678")
+      val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38))
+      avroRec.put("decimal", bytes)
+      dataFileWriter.append(avroRec)
+      dataFileWriter.flush()
+      dataFileWriter.close()
+
+      val msg = intercept[SparkException] {
+        spark.read.format("avro").load(s"$dir.avro").collect()
+      }.getCause.getMessage
+      assert(msg.contains("Unscaled value too large for precision"))
+    }
+  }
+
```
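The asserted message originates in Spark's Decimal, which rejects any unscaled value wider than the declared precision. The same check can be triggered directly; a sketch against Decimal's apply(unscaled, precision, scale):

```scala
import org.apache.spark.sql.types.Decimal

// Fits: 6 digits within precision 9 (value 1234.56).
val ok = Decimal(123456L, 9, 2)

// Throws ArithmeticException("Unscaled value too large for precision"):
// 6 digits do not fit in precision 4.
val bad = Decimal(123456L, 4, 2)
```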
| test("Array data types") { | ||
| withTempPath { dir => | ||
| val testSchema = StructType(Seq( | ||
|
|
```diff
@@ -689,7 +788,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

     // DecimalType should be converted to string
     val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
-    assert(decimals.map(_(0)).contains("3.14"))
+    assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14")))

     // There should be a null entry
     val length = spark.read.format("avro").load(avroDir).select("Length").collect()
```
Reviewer: parquet can convert binary to unscaled long directly, shall we follow?

Author: Comparing to binaryToUnscaledLong, I think using the method from the Avro library makes more sense. Also, binaryToUnscaledLong uses the underlying byte array of the Parquet Binary without copying it (if we created a shared util method for both, the Parquet data source would lose that optimization). For performance, we could add a similar method on the Avro side; I tried binaryToUnscaledLong there and it works. I can change it if you insist.

Reviewer: OK, let's leave it. We can always add it later.
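For context, the Parquet-side optimization being discussed: when the unscaled value fits in 8 bytes, the big-endian two's-complement bytes can be folded directly into a Long, skipping BigDecimal entirely. An illustrative sketch, not the actual Parquet or Spark code:

```scala
// Interpret up to 8 big-endian two's-complement bytes as an unscaled Long:
// start from the sign extension of the first byte, then shift each byte in.
def binaryToUnscaledLong(bytes: Array[Byte]): Long = {
  require(bytes.nonEmpty && bytes.length <= 8, "value must fit in a Long")
  var unscaled: Long = if (bytes(0) < 0) -1L else 0L // sign extension
  var i = 0
  while (i < bytes.length) {
    unscaled = (unscaled << 8) | (bytes(i) & 0xffL)
    i += 1
  }
  unscaled
}

binaryToUnscaledLong(Array[Byte](0x01, 0x38))        // 312  -> 3.12 at scale 2
binaryToUnscaledLong(Array[Byte](0xff.toByte, 0x38)) // -200 -> -2.00 at scale 2
```

Avro's DecimalConversion always materializes a BigDecimal, which is the trade-off the thread weighs against this copy-free path.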