Skip to content

Commit

Permalink
[SPARK-35139][SQL] Support ANSI intervals as Arrow Column vectors
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
 Support YearMonthIntervalType and DayTimeIntervalType to extend ArrowColumnVector

### Why are the changes needed?
https://issues.apache.org/jira/browse/SPARK-35139

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
1. By checking coding style via:
    $ ./dev/scalastyle
    $ ./dev/lint-java
2. Run the test "ArrowWriterSuite"

Closes #32340 from Peng-Lei/SPARK-35139.

Authored-by: PengLei <18066542445@189.cn>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
Peng-Lei authored and cloud-fan committed Apr 27, 2021
1 parent 7f51106 commit eb08b90
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 1 deletion.
Expand Up @@ -19,12 +19,16 @@

import org.apache.arrow.vector.*;
import org.apache.arrow.vector.complex.*;
import org.apache.arrow.vector.holders.NullableIntervalDayHolder;
import org.apache.arrow.vector.holders.NullableVarCharHolder;

import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;

import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_DAY;
import static org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS;

/**
* A column vector backed by Apache Arrow. Currently calendar interval type and map type are not
* supported.
Expand Down Expand Up @@ -172,6 +176,10 @@ public ArrowColumnVector(ValueVector vector) {
}
} else if (vector instanceof NullVector) {
accessor = new NullAccessor((NullVector) vector);
} else if (vector instanceof IntervalYearVector) {
accessor = new IntervalYearAccessor((IntervalYearVector) vector);
} else if (vector instanceof IntervalDayVector) {
accessor = new IntervalDayAccessor((IntervalDayVector) vector);
} else {
throw new UnsupportedOperationException();
}
Expand Down Expand Up @@ -508,4 +516,37 @@ private static class NullAccessor extends ArrowVectorAccessor {
super(vector);
}
}

private static class IntervalYearAccessor extends ArrowVectorAccessor {

private final IntervalYearVector accessor;

IntervalYearAccessor(IntervalYearVector vector) {
super(vector);
this.accessor = vector;
}

@Override
int getInt(int rowId) {
return accessor.get(rowId);
}
}

private static class IntervalDayAccessor extends ArrowVectorAccessor {

private final IntervalDayVector accessor;
private final NullableIntervalDayHolder intervalDayHolder = new NullableIntervalDayHolder();

IntervalDayAccessor(IntervalDayVector vector) {
super(vector);
this.accessor = vector;
}

@Override
long getLong(int rowId) {
accessor.get(rowId, intervalDayHolder);
return Math.addExact(Math.multiplyExact(intervalDayHolder.days, MICROS_PER_DAY),
intervalDayHolder.milliseconds * MICROS_PER_MILLIS);
}
}
}
Expand Up @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit}
import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -54,6 +54,8 @@ private[sql] object ArrowUtils {
new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
}
case NullType => ArrowType.Null.INSTANCE
case YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
case DayTimeIntervalType => new ArrowType.Interval(IntervalUnit.DAY_TIME)
case _ =>
throw new UnsupportedOperationException(s"Unsupported data type: ${dt.catalogString}")
}
Expand All @@ -74,6 +76,8 @@ private[sql] object ArrowUtils {
case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
case ArrowType.Null.INSTANCE => NullType
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType
case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType
case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt")
}

Expand Down
Expand Up @@ -48,6 +48,8 @@ class ArrowUtilsSuite extends SparkFunSuite {
roundtrip(BinaryType)
roundtrip(DecimalType.SYSTEM_DEFAULT)
roundtrip(DateType)
roundtrip(YearMonthIntervalType)
roundtrip(DayTimeIntervalType)
val tsExMsg = intercept[UnsupportedOperationException] {
roundtrip(TimestampType)
}
Expand Down
Expand Up @@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.DateTimeConstants.{MICROS_PER_DAY, MICROS_PER_MILLIS}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
Expand Down Expand Up @@ -74,6 +75,8 @@ object ArrowWriter {
}
new StructWriter(vector, children.toArray)
case (NullType, vector: NullVector) => new NullWriter(vector)
case (YearMonthIntervalType, vector: IntervalYearVector) => new IntervalYearWriter(vector)
case (DayTimeIntervalType, vector: IntervalDayVector) => new IntervalDayWriter(vector)
case (dt, _) =>
throw QueryExecutionErrors.unsupportedDataTypeError(dt)
}
Expand Down Expand Up @@ -394,3 +397,28 @@ private[arrow] class NullWriter(val valueVector: NullVector) extends ArrowFieldW
override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
}
}

private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector)
extends ArrowFieldWriter {
override def setNull(): Unit = {
valueVector.setNull(count)
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
valueVector.setSafe(count, input.getInt(ordinal));
}
}

private[arrow] class IntervalDayWriter(val valueVector: IntervalDayVector)
extends ArrowFieldWriter {
override def setNull(): Unit = {
valueVector.setNull(count)
}

override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
val totalMicroseconds = input.getLong(ordinal)
val days = totalMicroseconds / MICROS_PER_DAY
val millis = (totalMicroseconds % MICROS_PER_DAY) / MICROS_PER_MILLIS
valueVector.set(count, days.toInt, millis.toInt)
}
}
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.arrow

import org.apache.arrow.vector.IntervalDayVector

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util._
Expand Down Expand Up @@ -54,6 +56,8 @@ class ArrowWriterSuite extends SparkFunSuite {
case BinaryType => reader.getBinary(rowId)
case DateType => reader.getInt(rowId)
case TimestampType => reader.getLong(rowId)
case YearMonthIntervalType => reader.getInt(rowId)
case DayTimeIntervalType => reader.getLong(rowId)
}
assert(value === datum)
}
Expand All @@ -73,6 +77,33 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DateType, Seq(0, 1, 2, null, 4))
check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles")
check(NullType, Seq(null, null, null))
check(YearMonthIntervalType, Seq(null, 0, 1, -1, Int.MaxValue, Int.MinValue))
check(DayTimeIntervalType, Seq(null, 0L, 1000L, -1000L, (Long.MaxValue - 807L),
(Long.MinValue + 808L)))
}

test("long overflow for DayTimeIntervalType")
{
val schema = new StructType().add("value", DayTimeIntervalType, nullable = true)
val writer = ArrowWriter.create(schema, null)
val reader = new ArrowColumnVector(writer.root.getFieldVectors().get(0))
val valueVector = writer.root.getFieldVectors().get(0).asInstanceOf[IntervalDayVector]

valueVector.set(0, 106751992, 0)
valueVector.set(1, 106751991, Int.MaxValue)

// first long overflow for test Math.multiplyExact()
val msg = intercept[java.lang.ArithmeticException] {
reader.getLong(0)
}.getMessage
assert(msg.equals("long overflow"))

// second long overflow for test Math.addExact()
val msg1 = intercept[java.lang.ArithmeticException] {
reader.getLong(1)
}.getMessage
assert(msg1.equals("long overflow"))
writer.root.close()
}

test("get multiple") {
Expand All @@ -97,6 +128,8 @@ class ArrowWriterSuite extends SparkFunSuite {
case DoubleType => reader.getDoubles(0, data.size)
case DateType => reader.getInts(0, data.size)
case TimestampType => reader.getLongs(0, data.size)
case YearMonthIntervalType => reader.getInts(0, data.size)
case DayTimeIntervalType => reader.getLongs(0, data.size)
}
assert(values === data)

Expand All @@ -111,6 +144,8 @@ class ArrowWriterSuite extends SparkFunSuite {
check(DoubleType, (0 until 10).map(_.toDouble))
check(DateType, (0 until 10))
check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles")
check(YearMonthIntervalType, (0 until 10))
check(DayTimeIntervalType, (-10 until 10).map(_ * 1000.toLong))
}

test("array") {
Expand Down

0 comments on commit eb08b90

Please sign in to comment.