From ad4698629f005b113a0f02c2f8a1faa32a8f8aaa Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Thu, 30 Jul 2015 19:04:31 +0800
Subject: [PATCH] Support IntervalType for Parquet.

---
 .../spark/sql/catalyst/ScalaReflection.scala  |  4 +-
 .../org/apache/spark/sql/types/DataType.scala |  4 +-
 .../spark/sql/execution/datasources/ddl.scala |  3 +-
 .../sql/parquet/CatalystRowConverter.scala    | 40 ++++++++++++++++++-
 .../sql/parquet/CatalystSchemaConverter.scala |  5 ++-
 .../sql/parquet/ParquetTableSupport.scala     | 18 ++++++++-
 .../spark/sql/parquet/ParquetIOSuite.scala    | 16 ++++++++
 .../sql/parquet/ParquetSchemaSuite.scala      |  8 ++++
 8 files changed, 91 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 2442341da106d..42472e4f06786 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -135,6 +135,7 @@ trait ScalaReflection {
       case t if t <:< localTypeOf[java.math.BigDecimal] =>
         Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
       case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
+      case t if t <:< localTypeOf[CalendarInterval] => Schema(CalendarIntervalType, nullable = true)
       case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
       case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
       case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -169,6 +170,7 @@ trait ScalaReflection {
     case obj: java.sql.Date => DateType
     case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
     case obj: Decimal => DecimalType.SYSTEM_DEFAULT
+    case obj: CalendarInterval => CalendarIntervalType
     case obj: java.sql.Timestamp => TimestampType
     case null => NullType
     // For other cases, there is no obvious mapping from the type of the given object to a
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index f4428c2e8b202..9251192dbf601 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -98,7 +98,8 @@ object DataType {
 
   private val nonDecimalNameToType = {
     Seq(NullType, DateType, TimestampType, BinaryType,
-      IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+      IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType,
+      CalendarIntervalType)
       .map(t => t.typeName -> t).toMap
   }
 
@@ -189,6 +190,7 @@ object DataType {
       | "DecimalType()" ^^^ DecimalType.USER_DEFAULT
       | fixedDecimalType
       | "TimestampType" ^^^ TimestampType
+      | "CalendarIntervalType" ^^^ CalendarIntervalType
       )
 
   protected lazy val fixedDecimalType: Parser[DataType] =
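Registering CalendarIntervalType in nonDecimalNameToType (and in the case-class string parser) is what lets a schema containing an interval column round-trip through its serialized form, such as the JSON schema Spark embeds in Parquet footers. A minimal sketch of that round trip, assuming CalendarIntervalType keeps its default derived typeName:

    import org.apache.spark.sql.types._

    val schema = StructType(Seq(StructField("i", CalendarIntervalType)))
    // The JSON form records the field type by its typeName; parsing it back
    // goes through nonDecimalNameToType, which this patch extends.
    val parsed = DataType.fromJson(schema.json)
    assert(parsed == schema)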
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 0cdb407ad57b9..37a024ac75889 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -308,7 +308,8 @@ private[sql] object ResolvedDataSource {
       mode: SaveMode,
       options: Map[String, String],
       data: DataFrame): ResolvedDataSource = {
-    if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
+    if (provider != "parquet" &&
+        data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
       throw new AnalysisException("Cannot save interval data type into external storage.")
     }
     val clazz: Class[_] = lookupDataSource(provider)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index e00bd90edb3dd..175adb544456b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
  * A [[ParentContainerUpdater]] is used by a Parquet converter to set converted values to some
@@ -140,6 +140,9 @@ private[parquet] class CatalystRowConverter(
           updater.setShort(value.asInstanceOf[ShortType#InternalType])
         }
 
+      case CalendarIntervalType =>
+        new CatalystCalendarIntervalConverter(updater)
+
       case t: DecimalType =>
         new CatalystDecimalConverter(t, updater)
 
@@ -236,6 +239,41 @@ private[parquet] class CatalystRowConverter(
     }
   }
 
+  /**
+   * Parquet converter for CalendarInterval.
+   */
+  private final class CatalystCalendarIntervalConverter(updater: ParentContainerUpdater)
+    extends PrimitiveConverter {
+
+    // Decodes a 12-byte FIXED_LEN_BYTE_ARRAY of little-endian ints: months, days, milliseconds
+    override def addBinary(value: Binary): Unit = {
+      updater.set(toCalendarInterval(value))
+    }
+
+    private def toCalendarInterval(value: Binary): CalendarInterval = {
+      val bytes = value.getBytes
+
+      var milliseconds: Int = 0
+      var i = 11
+      while (i >= 8) {
+        milliseconds = (milliseconds << 8) | (bytes(i) & 0xff)
+        i -= 1
+      }
+      var days: Int = 0
+      while (i >= 4) {
+        days = (days << 8) | (bytes(i) & 0xff)
+        i -= 1
+      }
+      var months: Int = 0
+      while (i >= 0) {
+        months = (months << 8) | (bytes(i) & 0xff)
+        i -= 1
+      }
+      new CalendarInterval(months, days * CalendarInterval.MICROS_PER_DAY +
+        milliseconds * CalendarInterval.MICROS_PER_MILLI)
+    }
+  }
+
   /**
    * Parquet converter for fixed-precision decimals.
    */
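For reference, the Parquet format defines INTERVAL as an annotation on a 12-byte fixed_len_byte_array holding three little-endian unsigned integers: months in bytes 0-3, days in bytes 4-7, and milliseconds in bytes 8-11. A self-contained sketch of that layout (the helper names are illustrative, not part of the patch):

    import java.nio.{ByteBuffer, ByteOrder}

    // Pack months/days/millis into the 12-byte Parquet INTERVAL layout.
    def encodeInterval(months: Int, days: Int, millis: Int): Array[Byte] =
      ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
        .putInt(months).putInt(days).putInt(millis).array()

    // Unpack the same layout back into its three fields.
    def decodeInterval(bytes: Array[Byte]): (Int, Int, Int) = {
      val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
      (buf.getInt, buf.getInt, buf.getInt)
    }

    assert(decodeInterval(encodeInterval(1, 2, 3)) == ((1, 2, 3)))

The converter above walks the same 12 bytes in place, accumulating each little-endian group from its most significant byte down, which avoids allocating a ByteBuffer per value.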
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
index d43ca95b4eea0..09fa71eb7e968 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -187,7 +187,7 @@ private[parquet] class CatalystSchemaConverter(
       case FIXED_LEN_BYTE_ARRAY =>
         originalType match {
           case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength))
-          case INTERVAL => typeNotImplemented()
+          case INTERVAL => CalendarIntervalType
           case _ => illegalType()
         }
 
@@ -358,6 +358,9 @@ private[parquet] class CatalystSchemaConverter(
     case DateType =>
       Types.primitive(INT32, repetition).as(DATE).named(field.name)
 
+    case CalendarIntervalType =>
+      Types.primitive(FIXED_LEN_BYTE_ARRAY, repetition).as(INTERVAL).length(12).named(field.name)
+
     // NOTE: Spark SQL TimestampType is NOT a well defined type in Parquet format spec.
     //
     // As stated in PARQUET-323, Parquet `INT96` was originally introduced to represent nanosecond
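With this mapping, an optional interval field in a Catalyst schema corresponds to the Parquet declaration exercised by the schema test at the end of this patch. A sketch using parquet-mr's schema parser (already on Spark's classpath for these suites) to construct the same message type directly:

    import org.apache.parquet.schema.MessageTypeParser

    val intervalMessage = MessageTypeParser.parseMessageType(
      """message root {
        |  optional fixed_len_byte_array(12) f1 (INTERVAL);
        |}
      """.stripMargin)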
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 78ecfad1d57c6..d1a423556258d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
  * A `parquet.hadoop.api.WriteSupport` for Row objects.
@@ -95,7 +95,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
         case t @ StructType(_) => writeStruct(
           t, value.asInstanceOf[CatalystConverter.StructScalaType[_]])
-        case _ => writePrimitive(schema.asInstanceOf[AtomicType], value)
+        case _ => writePrimitive(schema, value)
       }
     }
   }
 
@@ -117,6 +117,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
         Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
       case DecimalType.Fixed(precision, _) =>
         writeDecimal(value.asInstanceOf[Decimal], precision)
+      case CalendarIntervalType =>
+        writeCalendarInterval(value.asInstanceOf[CalendarInterval])
       case _ => sys.error(s"Do not know how to writer $schema to consumer")
     }
   }
@@ -200,6 +202,16 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
 
   // Scratch array used to write decimals as fixed-length byte array
   private[this] var reusableDecimalBytes = new Array[Byte](16)
 
+  private[parquet] def writeCalendarInterval(ci: CalendarInterval): Unit = {
+    val days: Int = (ci.microseconds / CalendarInterval.MICROS_PER_DAY).toInt
+    val rest: Long = ci.microseconds % CalendarInterval.MICROS_PER_DAY
+    val milliseconds: Int = (rest / CalendarInterval.MICROS_PER_MILLI).toInt
+
+    val buffer = ByteBuffer.allocate(12).order(ByteOrder.LITTLE_ENDIAN)
+    buffer.putInt(ci.months).putInt(days).putInt(milliseconds)
+    writer.addBinary(Binary.fromByteArray(buffer.array()))
+  }
+
   private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
     val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision)
@@ -295,6 +307,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
         writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
       case DecimalType.Fixed(precision, _) =>
         writeDecimal(record.getDecimal(index), precision)
+      case CalendarIntervalType =>
+        writeCalendarInterval(record.getInterval(index))
       case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
     }
   }
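A CalendarInterval carries months plus a single microseconds total, while the INTERVAL layout wants separate day and millisecond counts, so writeCalendarInterval splits the microseconds into whole days and a millisecond remainder; anything finer than one millisecond cannot be represented and is truncated. The arithmetic as a standalone sketch (the constants mirror CalendarInterval.MICROS_PER_DAY and MICROS_PER_MILLI):

    val MICROS_PER_MILLI = 1000L
    val MICROS_PER_DAY = 24L * 60 * 60 * 1000 * 1000

    val micros = 90061001001L  // 1 day, 1 hour, 1 minute, 1 second, 1 ms, 1 us
    val days = (micros / MICROS_PER_DAY).toInt    // 1
    val rest = micros % MICROS_PER_DAY            // 3661001001
    val millis = (rest / MICROS_PER_MILLI).toInt  // 3661001
    val dropped = rest % MICROS_PER_MILLI         // the trailing 1 us is lost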
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index b415da5b8c136..d3a89908b8963 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
 
 // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
 // with an empty configuration (it is after all not intended to be used in this way?)
@@ -115,6 +116,21 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
     }
   }
 
+  test("interval") {
+    def makeIntervalRDD(): DataFrame =
+      sqlContext.sparkContext
+        .parallelize(1 to 10)
+        .map(i => Tuple1(new CalendarInterval(i, i * CalendarInterval.MICROS_PER_DAY + i * 2000L)))
+        .toDF()
+        .select($"_1")
+
+    withTempPath { dir =>
+      val data = makeIntervalRDD()
+      data.write.parquet(dir.getCanonicalPath)
+      checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
+    }
+  }
+
   test("date type") {
     def makeDateRDD(): DataFrame =
       sqlContext.sparkContext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index 4a0b3b60f419d..0642002e785bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -913,4 +913,12 @@ class ParquetSchemaSuite extends ParquetSchemaTest {
       |  optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3));
       |}
     """.stripMargin)
+
+  testSchema(
+    "CalendarInterval",
+    StructType(Seq(StructField("f1", CalendarIntervalType))),
+    """message root {
+      |  optional fixed_len_byte_array(12) f1 (INTERVAL);
+      |}
+    """.stripMargin)
 }
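Taken together, the patch lets interval columns be persisted to and read back from the Parquet data source, while ResolvedDataSource continues to reject interval data for every other provider. A usage sketch (Spark 1.5-era API; the path is illustrative):

    import org.apache.spark.unsafe.types.CalendarInterval
    import sqlContext.implicits._  // e.g. in spark-shell

    // months = 1, microseconds = 2 hours
    val df = sqlContext.sparkContext
      .parallelize(Seq(Tuple1(new CalendarInterval(1, 2L * 60 * 60 * 1000 * 1000))))
      .toDF("interval_col")

    df.write.parquet("/tmp/interval_demo")  // previously raised AnalysisException
    sqlContext.read.parquet("/tmp/interval_demo").collect()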