external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala
@@ -17,12 +17,14 @@

package org.apache.spark.sql.avro

import java.math.{BigDecimal}
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
@@ -38,6 +40,8 @@ import org.apache.spark.unsafe.types.UTF8String
* A deserializer to deserialize data in avro format to data in catalyst format.
*/
class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
private lazy val decimalConversions = new DecimalConversion()

private val converter: Any => Any = rootCatalystType match {
// A shortcut for empty schema.
case st: StructType if st.isEmpty =>
@@ -138,10 +142,21 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
bytes
case b: Array[Byte] => b
case other => throw new RuntimeException(s"$other is not a valid avro binary.")

}
updater.set(ordinal, bytes)

case (FIXED, d: DecimalType) => (updater, ordinal, value) =>
val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType,
LogicalTypes.decimal(d.precision, d.scale))
Review comment (Contributor): Parquet can convert binary to an unscaled long directly; shall we follow?

Reply (Member, author): Compared to binaryToUnscaledLong, I think using the method from the Avro library makes more sense. Also, binaryToUnscaledLong works on the underlying byte array of the Parquet Binary without copying it, so if we created a shared util method for both, the Parquet data source would lose that optimization. For performance we could add a similar method on the Avro side; I tried the binaryToUnscaledLong approach with Avro and it works, so I can change it if you insist.

Review comment (Contributor): OK, let's leave it; we can always add it later.
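For reference, a rough sketch of the kind of helper being discussed (hypothetical; the name, signature, and placement are assumptions, and it is not part of this PR):

```scala
import java.nio.ByteBuffer

object AvroDecimalUtil {
  // Decode a big-endian, two's-complement unscaled value straight into a Long,
  // in the spirit of Parquet's binaryToUnscaledLong. Only valid when the decimal
  // fits in a Long, i.e. precision <= Decimal.MAX_LONG_DIGITS (at most 8 bytes).
  def bytesToUnscaledLong(buffer: ByteBuffer): Long = {
    val bytes = new Array[Byte](buffer.remaining())
    buffer.duplicate().get(bytes) // read from a duplicate so the caller's position is untouched
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
      i += 1
    }
    // Sign-extend from the number of bits actually present.
    val bits = 8 * bytes.length
    (unscaled << (64 - bits)) >> (64 - bits)
  }
}
```

Such a helper could only replace the BYTES path when the unscaled value fits in a Long, which is the same condition createDecimal checks below via Decimal.MAX_LONG_DIGITS.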

val decimal = createDecimal(bigDecimal, d.precision, d.scale)
updater.setDecimal(ordinal, decimal)

case (BYTES, d: DecimalType) => (updater, ordinal, value) =>
val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType,
LogicalTypes.decimal(d.precision, d.scale))
val decimal = createDecimal(bigDecimal, d.precision, d.scale)
updater.setDecimal(ordinal, decimal)

case (RECORD, st: StructType) =>
val writeRecord = getRecordWriter(avroType, st, path)
(updater, ordinal, value) =>
@@ -263,6 +278,17 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
s"Target Catalyst type: $rootCatalystType")
}

// TODO: move the following method in Decimal object on creating Decimal from BigDecimal?
private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
if (precision <= Decimal.MAX_LONG_DIGITS) {
// Constructs a `Decimal` with an unscaled `Long` value if possible.
Decimal(decimal.unscaledValue().longValue(), precision, scale)
} else {
// Otherwise, resorts to an unscaled `BigInteger` instead.
Decimal(decimal, precision, scale)
}
}

private def getRecordWriter(
avroType: Schema,
sqlType: StructType,
@@ -334,6 +360,7 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value)
def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value)
def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value)
def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value)
}

final class RowUpdater(row: InternalRow) extends CatalystDataUpdater {
@@ -347,6 +374,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value)
override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value)
override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value)
override def setDecimal(ordinal: Int, value: Decimal): Unit =
row.setDecimal(ordinal, value, value.precision)
}

final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater {
@@ -360,5 +389,6 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value)
override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value)
override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value)
override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value)
}
}
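The createDecimal helper above avoids allocating a java.math.BigDecimal-backed Decimal when the precision fits in a Long. A minimal sketch of the two construction paths it chooses between (the object name is illustrative; the Decimal factory calls mirror the ones in the diff):

```scala
import org.apache.spark.sql.types.Decimal

object CreateDecimalSketch {
  def main(args: Array[String]): Unit = {
    // Fast path: precision <= Decimal.MAX_LONG_DIGITS (18), so the value can be
    // backed by an unscaled Long plus precision/scale metadata.
    val small = new java.math.BigDecimal("3.14")
    println(Decimal(small.unscaledValue().longValue(), 3, 2)) // 3.14

    // Slow path: too many digits for a Long, so keep the BigDecimal itself.
    val big = new java.math.BigDecimal("12345678901234567890.12") // precision 22
    println(Decimal(big, 22, 2))
  }
}
```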
external/avro/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -21,15 +21,17 @@ import java.nio.ByteBuffer

import scala.collection.JavaConverters._

import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes.{TimestampMicros, TimestampMillis}
import org.apache.avro.Schema
import org.apache.avro.Schema.Type
import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record}
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._

/**
@@ -67,6 +69,8 @@ class AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable:

private type Converter = (SpecializedGetters, Int) => Any

private lazy val decimalConversions = new DecimalConversion()

private def newConverter(catalystType: DataType, avroType: Schema): Converter = {
catalystType match {
case NullType =>
@@ -86,7 +90,11 @@
case DoubleType =>
(getter, ordinal) => getter.getDouble(ordinal)
case d: DecimalType =>
(getter, ordinal) => getter.getDecimal(ordinal, d.precision, d.scale).toString
(getter, ordinal) =>
val decimal = getter.getDecimal(ordinal, d.precision, d.scale)
decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType,
LogicalTypes.decimal(d.precision, d.scale))

case StringType => avroType.getType match {
case Type.ENUM =>
import scala.collection.JavaConverters._
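The serializer now writes decimals with Avro's DecimalConversion instead of converting them to strings. A minimal standalone sketch of that conversion (names are illustrative, and it drives the Avro library directly rather than going through AvroSerializer):

```scala
import org.apache.avro.{LogicalTypes, SchemaBuilder}
import org.apache.avro.Conversions.DecimalConversion

object DecimalToFixedSketch {
  def main(args: Array[String]): Unit = {
    val logicalType = LogicalTypes.decimal(4, 2)
    // Precision 4 needs 2 bytes; the PR derives the size with Decimal.minBytesForPrecision.
    val fixedSchema = logicalType.addToSchema(SchemaBuilder.fixed("example.fixed").size(2))
    val conversions = new DecimalConversion()

    // Encode: the fixed's bytes hold the two's-complement unscaled value 314 (0x01, 0x3A).
    val fixed = conversions.toFixed(new java.math.BigDecimal("3.14"), fixedSchema, logicalType)
    // Decode: the deserializer side does the inverse with fromFixed.
    println(conversions.fromFixed(fixed, fixedSchema, logicalType)) // 3.14
  }
}
```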
external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala
@@ -18,19 +18,28 @@
package org.apache.spark.sql.avro

import scala.collection.JavaConverters._
import scala.util.Random

import com.fasterxml.jackson.annotation.ObjectIdGenerators.UUIDGenerator
import org.apache.avro.{LogicalType, LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.LogicalTypes.{Date, TimestampMicros, TimestampMillis}
import org.apache.avro.LogicalTypes.{Date, Decimal, TimestampMicros, TimestampMillis}
import org.apache.avro.Schema.Type._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator
import org.apache.spark.sql.internal.SQLConf.AvroOutputTimestampType
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.Decimal.{maxPrecisionForBytes, minBytesForPrecision}

/**
* This object contains method that are used to convert sparkSQL schemas to avro schemas and vice
* versa.
*/
object SchemaConverters {
private lazy val uuidGenerator = RandomUUIDGenerator(new Random().nextLong())
Review comment (Member): This seems to be an unused value.

Reply (Member, author): Yes, I made a PR to clean it up: #34472


private lazy val nullSchema = Schema.create(Schema.Type.NULL)

case class SchemaType(dataType: DataType, nullable: Boolean)

/**
@@ -44,14 +53,20 @@ object SchemaConverters {
}
case STRING => SchemaType(StringType, nullable = false)
case BOOLEAN => SchemaType(BooleanType, nullable = false)
case BYTES => SchemaType(BinaryType, nullable = false)
case BYTES | FIXED => avroSchema.getLogicalType match {
// For FIXED type, if the precision requires more bytes than fixed size, the logical
// type will be null, which is handled by Avro library.
case d: Decimal => SchemaType(DecimalType(d.getPrecision, d.getScale), nullable = false)
case _ => SchemaType(BinaryType, nullable = false)
}

case DOUBLE => SchemaType(DoubleType, nullable = false)
case FLOAT => SchemaType(FloatType, nullable = false)
case LONG => avroSchema.getLogicalType match {
case _: TimestampMillis | _: TimestampMicros => SchemaType(TimestampType, nullable = false)
case _ => SchemaType(LongType, nullable = false)
}
case FIXED => SchemaType(BinaryType, nullable = false)

case ENUM => SchemaType(StringType, nullable = false)

case RECORD =>
@@ -114,32 +129,36 @@ object SchemaConverters {
prevNameSpace: String = "",
outputTimestampType: AvroOutputTimestampType.Value = AvroOutputTimestampType.TIMESTAMP_MICROS)
: Schema = {
val builder = if (nullable) {
SchemaBuilder.builder().nullable()
} else {
SchemaBuilder.builder()
}
val builder = SchemaBuilder.builder()

catalystType match {
val schema = catalystType match {
case BooleanType => builder.booleanType()
case ByteType | ShortType | IntegerType => builder.intType()
case LongType => builder.longType()
case DateType => builder
.intBuilder()
.prop(LogicalType.LOGICAL_TYPE_PROP, LogicalTypes.date().getName)
.endInt()
case DateType =>
LogicalTypes.date().addToSchema(builder.intType())
case TimestampType =>
val timestampType = outputTimestampType match {
case AvroOutputTimestampType.TIMESTAMP_MILLIS => LogicalTypes.timestampMillis()
case AvroOutputTimestampType.TIMESTAMP_MICROS => LogicalTypes.timestampMicros()
case other =>
throw new IncompatibleSchemaException(s"Unexpected output timestamp type $other.")
}
builder.longBuilder().prop(LogicalType.LOGICAL_TYPE_PROP, timestampType.getName).endLong()
timestampType.addToSchema(builder.longType())

case FloatType => builder.floatType()
case DoubleType => builder.doubleType()
case _: DecimalType | StringType => builder.stringType()
case StringType => builder.stringType()
case d: DecimalType =>
val avroType = LogicalTypes.decimal(d.precision, d.scale)
val fixedSize = minBytesForPrecision(d.precision)
// Need to avoid naming conflict for the fixed fields
val name = prevNameSpace match {
case "" => s"$recordName.fixed"
case _ => s"$prevNameSpace.$recordName.fixed"
}
avroType.addToSchema(SchemaBuilder.fixed(name).size(fixedSize))

case BinaryType => builder.bytesType()
case ArrayType(et, containsNull) =>
builder.array()
@@ -164,6 +183,11 @@ object SchemaConverters {
// This should never happen.
case other => throw new IncompatibleSchemaException(s"Unexpected type $other.")
}
if (nullable) {
Schema.createUnion(schema, nullSchema)
} else {
schema
}
}
}
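With these changes, an Avro decimal column maps to Spark's DecimalType on read, and DecimalType maps to a named fixed carrying a decimal logical type on write (sized with minBytesForPrecision and wrapped in a union with "null" when nullable). A minimal sketch of the read-side mapping, assuming SchemaConverters.toSqlType is accessible from the caller's scope as shown in this diff:

```scala
import org.apache.avro.Schema
import org.apache.spark.sql.avro.SchemaConverters

object DecimalSchemaMappingSketch {
  def main(args: Array[String]): Unit = {
    val avroSchema = new Schema.Parser().parse(
      """{"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}""")
    // Before this change the column was exposed as BinaryType; with the logical-type
    // check it becomes a decimal. Expected: SchemaType(DecimalType(4,2), nullable = false)
    println(SchemaConverters.toSqlType(avroSchema))
  }
}
```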

external/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala
@@ -64,8 +64,6 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite with ExpressionEvalHelper
BinaryType)

protected def prepareExpectedResult(expected: Any): Any = expected match {
// Spark decimal is converted to avro string
case d: Decimal => UTF8String.fromString(d.toString)
// Spark byte and short both map to avro int
case b: Byte => b.toInt
case s: Short => s.toInt
103 changes: 101 additions & 2 deletions external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -25,7 +25,8 @@ import java.util.{TimeZone, UUID}

import scala.collection.JavaConverters._

import org.apache.avro.Schema
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.Schema.{Field, Type}
import org.apache.avro.file.{DataFileReader, DataFileWriter}
import org.apache.avro.generic.{GenericData, GenericDatumReader, GenericDatumWriter, GenericRecord}
@@ -494,6 +495,104 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkAnswer(df, expected)
}

test("Logical type: Decimal") {
val precision = 4
val scale = 2
val bytesFieldName = "bytes"
val bytesSchema = s"""{
"type":"bytes",
"logicalType":"decimal",
"precision":$precision,
"scale":$scale
}
"""

val fixedFieldName = "fixed"
val fixedSchema = s"""{
"type":"fixed",
"size":5,
"logicalType":"decimal",
"precision":$precision,
"scale":$scale,
"name":"foo"
}
"""
val avroSchema = s"""
{
"namespace": "logical",
"type": "record",
"name": "test",
"fields": [
{"name": "$bytesFieldName", "type": $bytesSchema},
{"name": "$fixedFieldName", "type": $fixedSchema}
]
}
"""
val schema = new Schema.Parser().parse(avroSchema)
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
val decimalConversion = new DecimalConversion
withTempDir { dir =>
val avroFile = s"$dir.avro"
dataFileWriter.create(schema, new File(avroFile))
val logicalType = LogicalTypes.decimal(precision, scale)
val data = Seq("1.23", "4.56", "78.90", "-1", "-2.31")
data.map { x =>
val avroRec = new GenericData.Record(schema)
val decimal = new java.math.BigDecimal(x).setScale(scale)
val bytes =
decimalConversion.toBytes(decimal, schema.getField(bytesFieldName).schema, logicalType)
avroRec.put(bytesFieldName, bytes)
val fixed =
decimalConversion.toFixed(decimal, schema.getField(fixedFieldName).schema, logicalType)
avroRec.put(fixedFieldName, fixed)
dataFileWriter.append(avroRec)
}
dataFileWriter.flush()
dataFileWriter.close()

val expected = data.map { x => Row(new java.math.BigDecimal(x), new java.math.BigDecimal(x)) }
val df = spark.read.format("avro").load(avroFile)
checkAnswer(df, expected)
checkAnswer(spark.read.format("avro").option("avroSchema", avroSchema).load(avroFile),
expected)

withTempPath { path =>
df.write.format("avro").save(path.toString)
checkAnswer(spark.read.format("avro").load(path.toString), expected)
}
}
}

test("Logical type: Decimal with too large precision") {
withTempDir { dir =>
val schema = new Schema.Parser().parse("""{
"namespace": "logical",
"type": "record",
"name": "test",
"fields": [{
"name": "decimal",
"type": {"type": "bytes", "logicalType": "decimal", "precision": 4, "scale": 2}
}]
}""")
val datumWriter = new GenericDatumWriter[GenericRecord](schema)
val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter)
dataFileWriter.create(schema, new File(s"$dir.avro"))
Review comment (Contributor): Let's either always use Python to write test files, or always use Java.

val avroRec = new GenericData.Record(schema)
val decimal = new java.math.BigDecimal("0.12345678901234567890123456789012345678")
val bytes = (new DecimalConversion).toBytes(decimal, schema, LogicalTypes.decimal(39, 38))
avroRec.put("decimal", bytes)
dataFileWriter.append(avroRec)
dataFileWriter.flush()
dataFileWriter.close()

val msg = intercept[SparkException] {
spark.read.format("avro").load(s"$dir.avro").collect()
}.getCause.getMessage
assert(msg.contains("Unscaled value too large for precision"))
}
}

test("Array data types") {
withTempPath { dir =>
val testSchema = StructType(Seq(
@@ -689,7 +788,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {

// DecimalType should be converted to string
val decimals = spark.read.format("avro").load(avroDir).select("Decimal").collect()
assert(decimals.map(_(0)).contains("3.14"))
assert(decimals.map(_(0)).contains(new java.math.BigDecimal("3.14")))

// There should be a null entry
val length = spark.read.format("avro").load(avroDir).select("Length").collect()