Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8945][SQL] Add add and subtract expressions for IntervalType #7398

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.Interval

abstract class UnaryArithmetic extends UnaryExpression {
self: Product =>
Expand Down Expand Up @@ -94,6 +95,8 @@ abstract class BinaryArithmetic extends BinaryOperator {
case ByteType | ShortType =>
defineCodeGen(ctx, ev,
(eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)")
case IntervalType =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"""$eval1.doOp($eval2, "$symbol")""")
case _ =>
defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
}
Expand All @@ -111,11 +114,17 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

// Add accepts numeric operands as well as interval operands.
// (The stale pre-diff line calling checkForNumericExpr was removed; the scraped
// diff had left both the old and new call in place, which is not valid Scala.)
protected def checkTypesInternal(t: DataType) =
  TypeUtils.checkForNumericAndIntervalExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
/**
 * Evaluates `input1 + input2` once both sides are known to be non-null.
 * Interval operands are added component-wise via [[Interval.add]]; all other
 * (numeric) operands use the resolved [[Numeric]] instance for `dataType`.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
  // Pattern match instead of isInstanceOf/asInstanceOf on the data type itself;
  // the operand casts are unavoidable because inputs arrive as Any.
  case _: IntervalType =>
    input1.asInstanceOf[Interval].add(input2.asInstanceOf[Interval])
  case _ =>
    numeric.plus(input1, input2)
}
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
Expand All @@ -126,11 +135,17 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

// Subtract accepts numeric operands as well as interval operands.
// (The stale pre-diff line calling checkForNumericExpr was removed; the scraped
// diff had left both the old and new call in place, which is not valid Scala.)
protected def checkTypesInternal(t: DataType) =
  TypeUtils.checkForNumericAndIntervalExpr(t, "operator " + symbol)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
/**
 * Evaluates `input1 - input2` once both sides are known to be non-null.
 * Interval operands are subtracted component-wise via [[Interval.subtract]];
 * all other (numeric) operands use the resolved [[Numeric]] instance.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
  // Pattern match instead of isInstanceOf/asInstanceOf on the data type itself;
  // the operand casts are unavoidable because inputs arrive as Any.
  case _: IntervalType =>
    input1.asInstanceOf[Interval].subtract(input2.asInstanceOf[Interval])
  case _ =>
    numeric.minus(input1, input2)
}
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.codehaus.janino.ClassBodyEvaluator
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types._


// These classes are here to avoid issues with serialization and integration with quasiquotes.
Expand Down Expand Up @@ -57,6 +57,7 @@ class CodeGenContext {
val references: mutable.ArrayBuffer[Expression] = new mutable.ArrayBuffer[Expression]()

val stringType: String = classOf[UTF8String].getName
val intervalType: String = classOf[Interval].getName
val decimalType: String = classOf[Decimal].getName

final val JAVA_BOOLEAN = "boolean"
Expand Down Expand Up @@ -127,6 +128,7 @@ class CodeGenContext {
case dt: DecimalType => decimalType
case BinaryType => "byte[]"
case StringType => stringType
case IntervalType => intervalType
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types._

object Literal {
def apply(v: Any): Literal = v match {
Expand All @@ -42,6 +42,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case i: Interval => Literal(i, IntervalType)
case null => Literal(null, NullType)
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ object TypeUtils {
}
}

/**
 * Checks that `t` is a type on which `caller` (e.g. "operator +") may operate:
 * any numeric type, the interval type, or NullType (untyped null literals).
 */
def checkForNumericAndIntervalExpr(t: DataType, caller: String): TypeCheckResult = {
  val accepted = t.isInstanceOf[NumericType] || t.isInstanceOf[IntervalType] || t == NullType
  if (accepted) {
    TypeCheckResult.TypeCheckSuccess
  } else {
    TypeCheckResult.TypeCheckFailure(s"$caller accepts numeric or interval types, not $t")
  }
}

def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = {
if (t.isInstanceOf[IntegralType] || t == NullType) {
TypeCheckResult.TypeCheckSuccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))

assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type")
assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type")
assertError(Add('booleanField, 'booleanField), "operator + accepts numeric or interval types")
assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric or interval types")
assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type")
Expand Down
13 changes: 13 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1492,4 +1492,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Currently we don't yet support nanosecond
checkIntervalParseError("select interval 23 nanosecond")
}

test("SPARK-8945: add and subtract expressions for interval type") {
  import org.apache.spark.unsafe.types.Interval

  // 7 weeks expressed in microseconds, shared by every expected value below.
  val sevenWeeksMicros = 7L * 1000 * 1000 * 3600 * 24 * 7
  // 3 years minus 3 months, expressed in months.
  val baseMonths = 12 * 3 - 3

  val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
  checkAnswer(df, Row(new Interval(baseMonths, sevenWeeksMicros + 123)))

  checkAnswer(df.select(df("i") + new Interval(2, 123)),
    Row(new Interval(baseMonths + 2, sevenWeeksMicros + 123 + 123)))

  checkAnswer(df.select(df("i") - new Interval(2, 123)),
    Row(new Interval(baseMonths - 2, sevenWeeksMicros + 123 - 123)))
}
}
25 changes: 25 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,31 @@ public Interval(int months, long microseconds) {
this.microseconds = microseconds;
}

public Interval doOp(Interval that, String op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think you want this -- it is super slow. Just remove this function, and define codegen in Add / Subtract, rather than relying on what BinaryArithmetic provides.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. I wanted to avoid defining codegen in two operations. Update later together.

Interval opRet = null;
switch (op) {
case "+":
opRet = add(that);
break;
case "-":
opRet = subtract(that);
break;
}
return opRet;
}

/** Returns a new Interval whose months and microseconds are the component-wise
 *  sums of this interval's and {@code that}'s. Neither operand is mutated. */
public Interval add(Interval that) {
  return new Interval(this.months + that.months, this.microseconds + that.microseconds);
}

/** Returns a new Interval whose months and microseconds are the component-wise
 *  differences of this interval's and {@code that}'s. Neither operand is mutated. */
public Interval subtract(Interval that) {
  return new Interval(this.months - that.months, this.microseconds - that.microseconds);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,44 @@ public void fromStringTest() {
assertEquals(Interval.fromString(input), null);
}

@Test
public void addTest() {
  // Positive components on both sides.
  Interval left = Interval.fromString("interval 3 month 1 hour");
  Interval right = Interval.fromString("interval 2 month 100 hour");
  assertEquals(left.add(right), new Interval(5, 101 * MICROS_PER_HOUR));

  // Mixed signs: negative left operand, positive right operand.
  left = Interval.fromString("interval -10 month -81 hour");
  right = Interval.fromString("interval 75 month 200 hour");
  assertEquals(left.add(right), new Interval(65, 119 * MICROS_PER_HOUR));
}

@Test
public void subtractTest() {
  // Positive components on both sides; microsecond part goes negative.
  Interval left = Interval.fromString("interval 3 month 1 hour");
  Interval right = Interval.fromString("interval 2 month 100 hour");
  assertEquals(left.subtract(right), new Interval(1, -99 * MICROS_PER_HOUR));

  // Mixed signs: negative left operand, positive right operand.
  left = Interval.fromString("interval -10 month -81 hour");
  right = Interval.fromString("interval 75 month 200 hour");
  assertEquals(left.subtract(right), new Interval(-85, -281 * MICROS_PER_HOUR));
}

private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
Expand Down