Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8945][SQL] Add add and subtract expressions for IntervalType #7398

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.Interval


case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def dataType: DataType = child.dataType

Expand All @@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
/**
 * Builds the generated-code snippet for negation by handing `defineCodeGen` a
 * string template selected from this expression's data type.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  dataType match {
    // Intervals are negated through their own `negate()` method.
    case _: IntervalType =>
      defineCodeGen(ctx, ev, operand => s"$operand.negate()")
    // Must stay ahead of the NumericType case: decimals use a method call, not `-`.
    case _: DecimalType =>
      defineCodeGen(ctx, ev, operand => s"$operand.unary_$$minus()")
    // Remaining numeric types: negate, then cast back to the type's Java representation.
    case numericType: NumericType =>
      defineCodeGen(ctx, ev, operand => s"(${ctx.javaType(numericType)})(-($operand))")
  }
}

protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
/**
 * Negates `input`. When this expression's data type is the interval type the value is
 * cast to `Interval` and negated with `Interval.negate()`; otherwise `numeric` is used.
 */
protected override def nullSafeEval(input: Any): Any = dataType match {
  case _: IntervalType => input.asInstanceOf[Interval].negate()
  case _ => numeric.negate(input)
}
}

case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def prettyName: String = "positive"

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def dataType: DataType = child.dataType

Expand Down Expand Up @@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic {

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

override def symbol: String = "+"
override def decimalMethod: String = "$plus"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
/**
 * Adds the two operands. When this expression's data type is the interval type both
 * values are cast to `Interval` and combined with `Interval.add`; otherwise `numeric`
 * performs the addition.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
  case _: IntervalType =>
    val augend = input1.asInstanceOf[Interval]
    val addend = input2.asInstanceOf[Interval]
    augend.add(addend)
  case _ =>
    numeric.plus(input1, input2)
}

/**
 * Builds the generated-code snippet for addition by handing `defineCodeGen` a string
 * template selected from this expression's data type.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  dataType match {
    // Intervals are combined through their own `add` method.
    case IntervalType =>
      defineCodeGen(ctx, ev, (lhs, rhs) => s"$lhs.add($rhs)")
    // Decimals call the `$plus` method instead of using the `+` operator.
    case _: DecimalType =>
      defineCodeGen(ctx, ev, (lhs, rhs) => s"$lhs.$$plus($rhs)")
    // The operator result is cast back to the Java type of byte/short.
    case ByteType | ShortType =>
      defineCodeGen(ctx, ev,
        (lhs, rhs) => s"(${ctx.javaType(dataType)})($lhs $symbol $rhs)")
    // Everything else uses the bare operator.
    case _ =>
      defineCodeGen(ctx, ev, (lhs, rhs) => s"$lhs $symbol $rhs")
  }
}
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

override def symbol: String = "-"
override def decimalMethod: String = "$minus"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
/**
 * Subtracts `input2` from `input1`. When this expression's data type is the interval
 * type both values are cast to `Interval` and combined with `Interval.subtract`;
 * otherwise `numeric` performs the subtraction.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
  case _: IntervalType =>
    val minuend = input1.asInstanceOf[Interval]
    val subtrahend = input2.asInstanceOf[Interval]
    minuend.subtract(subtrahend)
  case _ =>
    numeric.minus(input1, input2)
}

/**
 * Builds the generated-code snippet for subtraction by handing `defineCodeGen` a string
 * template selected from this expression's data type.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  dataType match {
    // Intervals are combined through their own `subtract` method.
    case IntervalType =>
      defineCodeGen(ctx, ev, (lhs, rhs) => s"$lhs.subtract($rhs)")
    // Decimals call the `$minus` method instead of using the `-` operator.
    case _: DecimalType =>
      defineCodeGen(ctx, ev, (lhs, rhs) => s"$lhs.$$minus($rhs)")
    // The operator result is cast back to the Java type of byte/short.
    case ByteType | ShortType =>
      defineCodeGen(ctx, ev,
        (lhs, rhs) => s"(${ctx.javaType(dataType)})($lhs $symbol $rhs)")
    // Everything else uses the bare operator.
    case _ =>
      defineCodeGen(ctx, ev, (lhs, rhs) => s"$lhs $symbol $rhs")
  }
}
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.codehaus.janino.ClassBodyEvaluator
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types._


// These classes are here to avoid issues with serialization and integration with quasiquotes.
Expand Down Expand Up @@ -69,6 +69,7 @@ class CodeGenContext {
}

val stringType: String = classOf[UTF8String].getName
val intervalType: String = classOf[Interval].getName
val decimalType: String = classOf[Decimal].getName

final val JAVA_BOOLEAN = "boolean"
Expand Down Expand Up @@ -139,6 +140,7 @@ class CodeGenContext {
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
case IntervalType => intervalType
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types._

object Literal {
def apply(v: Any): Literal = v match {
Expand All @@ -42,6 +42,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case i: Interval => Literal(i, IntervalType)
case null => Literal(null, NullType)
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)

/**
 * The numeric types together with the interval type. Only the UnaryMinus, UnaryPositive,
 * Add and Subtract expressions accept this collection.
 */
val NumericAndInterval = TypeCollection(NumericType, IntervalType)

def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)

def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}

test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "expected to be of type numeric")
assertError(UnaryMinus('stringField), "type (numeric or interval)")
assertError(Abs('stringField), "expected to be of type numeric")
assertError(BitwiseNot('stringField), "expected to be of type integral")
}
Expand All @@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))

assertError(Add('booleanField, 'booleanField), "accepts numeric type")
assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type")
assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type")
assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Currently we don't yet support nanosecond
checkIntervalParseError("select interval 23 nanosecond")
}

test("SPARK-8945: add and subtract expressions for interval type") {
  import org.apache.spark.unsafe.types.Interval

  // Expected components of the interval literal queried below:
  // 3 years minus 3 months, and 7 weeks plus 123 microseconds (in microseconds).
  val baseMonths = 12 * 3 - 3
  val baseMicros = 7L * 1000 * 1000 * 3600 * 24 * 7 + 123

  val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
  val column = df("i")

  // The literal itself.
  checkAnswer(df, Row(new Interval(baseMonths, baseMicros)))

  // Binary plus with an interval literal.
  checkAnswer(
    df.select(column + new Interval(2, 123)),
    Row(new Interval(baseMonths + 2, baseMicros + 123)))

  // Binary minus with an interval literal.
  checkAnswer(
    df.select(column - new Interval(2, 123)),
    Row(new Interval(baseMonths - 2, baseMicros - 123)))

  // unary minus
  checkAnswer(
    df.select(-column),
    Row(new Interval(-baseMonths, -baseMicros)))
}
}
16 changes: 16 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ public Interval(int months, long microseconds) {
this.microseconds = microseconds;
}

/**
 * Returns a new interval whose months and microseconds are the field-wise sums of this
 * interval and {@code that}. Neither operand is modified.
 *
 * Uses plain Java integer arithmetic, so a sum outside the field's range wraps around
 * instead of raising an error.
 */
public Interval add(Interval that) {
  return new Interval(this.months + that.months, this.microseconds + that.microseconds);
}

/**
 * Returns a new interval whose months and microseconds are this interval's fields minus
 * the corresponding fields of {@code that}. Neither operand is modified.
 *
 * Uses plain Java integer arithmetic, so a difference outside the field's range wraps
 * around instead of raising an error.
 */
public Interval subtract(Interval that) {
  return new Interval(this.months - that.months, this.microseconds - that.microseconds);
}

/**
 * Returns a new interval with both the months and the microseconds fields negated.
 * This interval is not modified.
 */
public Interval negate() {
  int negatedMonths = -this.months;
  long negatedMicroseconds = -this.microseconds;
  return new Interval(negatedMonths, negatedMicroseconds);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ public void fromStringTest() {
assertEquals(Interval.fromString(input), null);
}

/**
 * Checks Interval.add on two parsed intervals: first with both operands positive, then
 * with a negative left operand.
 */
@Test
public void addTest() {
  String input = "interval 3 month 1 hour";
  String input2 = "interval 2 month 100 hour";

  Interval interval = Interval.fromString(input);
  Interval interval2 = Interval.fromString(input2);

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes a
  // failure message label the two values correctly.
  assertEquals(new Interval(5, 101 * MICROS_PER_HOUR), interval.add(interval2));

  input = "interval -10 month -81 hour";
  input2 = "interval 75 month 200 hour";

  interval = Interval.fromString(input);
  interval2 = Interval.fromString(input2);

  assertEquals(new Interval(65, 119 * MICROS_PER_HOUR), interval.add(interval2));
}

/**
 * Checks Interval.subtract on two parsed intervals: first with both operands positive
 * (the microseconds part of the result goes negative), then with a negative left operand.
 */
@Test
public void subtractTest() {
  String input = "interval 3 month 1 hour";
  String input2 = "interval 2 month 100 hour";

  Interval interval = Interval.fromString(input);
  Interval interval2 = Interval.fromString(input2);

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes a
  // failure message label the two values correctly.
  assertEquals(new Interval(1, -99 * MICROS_PER_HOUR), interval.subtract(interval2));

  input = "interval -10 month -81 hour";
  input2 = "interval 75 month 200 hour";

  interval = Interval.fromString(input);
  interval2 = Interval.fromString(input2);

  assertEquals(new Interval(-85, -281 * MICROS_PER_HOUR), interval.subtract(interval2));
}

private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
Expand Down