Commit

[SPARK-35729][SQL][TESTS] Check all day-time interval types in aggregate expressions

### What changes were proposed in this pull request?

This PR adds tests to check that `sum` and `avg` work with all the `DayTimeIntervalType` variations.
It also moves a DataFrame shared by the tests `SPARK-34837: Support ANSI SQL intervals by the aggregate function avg` and `SPARK-34716: Support ANSI SQL intervals by the aggregate function sum` into `SQLTestData.scala`, with small modifications.
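
For illustration, here is a minimal hand-written sketch of the `sum` property the new assertions cover (this is not code from this commit; the object name and values are made up, and it assumes Spark 3.2+ with ANSI interval support and a local `SparkSession`): aggregating a column cast to a narrowed day-time interval type such as `DayTimeIntervalType(DAY, HOUR)` keeps that exact type in the result schema.

```scala
import java.time.Duration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.DayTimeIntervalType
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR}

// Hypothetical standalone driver, only to illustrate the behaviour under test.
object SumDayTimeIntervalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    // Durations map to DAY TO SECOND intervals; narrow the column to DAY TO HOUR.
    val df = Seq(Duration.ofDays(1), Duration.ofHours(36))
      .toDF("d")
      .select($"d" cast DayTimeIntervalType(DAY, HOUR) as "day-hour")

    val summed = df.select(sum($"day-hour"))
    // `sum` keeps the input's day-time interval fields in its result type ...
    assert(summed.schema.head.dataType == DayTimeIntervalType(DAY, HOUR))
    // ... and returns the total as java.time.Duration: 1 day + 36 hours = 2 days 12 hours.
    assert(summed.head.getAs[Duration](0) == Duration.ofDays(2).plusHours(12))

    spark.stop()
  }
}
```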

### Why are the changes needed?

To ensure that aggregations over all day-time interval types produce the expected results.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.
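
Among other things, the new test asserts that `avg` widens the result type to the general `DayTimeIntervalType()` even for narrowed inputs. A companion sketch of that behaviour, as asserted at this commit (hypothetical code, not from this change; it assumes Spark 3.2+ and a local `SparkSession`):

```scala
import java.time.Duration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.avg
import org.apache.spark.sql.types.DayTimeIntervalType
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR}

// Hypothetical standalone driver, only to illustrate the behaviour under test.
object AvgDayTimeIntervalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = Seq(Duration.ofDays(1), Duration.ofHours(36))
      .toDF("d")
      .select($"d" cast DayTimeIntervalType(DAY, HOUR) as "day-hour")

    val averaged = df.select(avg($"day-hour"))
    // Unlike `sum`, `avg` reports the default DAY TO SECOND interval type,
    // matching the `avg(...)` StructFields asserted in the updated test.
    assert(averaged.schema.head.dataType == DayTimeIntervalType())
    // (1 day + 36 hours) / 2 = 30 hours.
    assert(averaged.head.getAs[Duration](0) == Duration.ofHours(30))

    spark.stop()
  }
}
```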

Closes #33042 from sarutak/check-interval-agg-dt.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
sarutak authored and MaxGekk committed Jun 23, 2021
1 parent 4761977 commit 2d3fa04
Showing 2 changed files with 327 additions and 64 deletions.
@@ -34,6 +34,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.DecimalData
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND}
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}

case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
@@ -1116,122 +1117,278 @@ class DataFrameAggregateSuite extends QueryTest
}

test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") {
val df = Seq(
(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null, null, null),
(3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)))
.toDF("class", "year-month", "year", "month", "day-time")
.select(
$"class",
$"year-month",
$"year" cast YearMonthIntervalType(YEAR) as "year",
$"month" cast YearMonthIntervalType(MONTH) as "month",
$"day-time")

val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day-time")

val sumDF = df.select(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
val sumDF = intervalData.select(
sum($"year-month"),
sum($"year"),
sum($"month"),
sum($"day-second"),
sum($"day-minute"),
sum($"day-hour"),
sum($"day"),
sum($"hour-second"),
sum($"hour-minute"),
sum($"hour"),
sum($"minute-second"),
sum($"minute"),
sum($"second"))
checkAnswer(sumDF,
Row(Period.of(2, 5, 0), Period.ofYears(27), Period.of(1, 1, 0), Duration.ofDays(0)))
Row(
Period.of(2, 5, 0),
Period.ofYears(28),
Period.of(1, 1, 0),
Duration.ofDays(9).plusHours(23).plusMinutes(29).plusSeconds(4),
Duration.ofDays(23).plusHours(8).plusMinutes(27),
Duration.ofDays(-8).plusHours(-7),
Duration.ofDays(1),
Duration.ofDays(1).plusHours(12).plusMinutes(2).plusSeconds(33),
Duration.ofMinutes(43),
Duration.ofHours(12),
Duration.ofMinutes(18).plusSeconds(3),
Duration.ofMinutes(52),
Duration.ofSeconds(20)))
assert(find(sumDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(sumDF.schema == StructType(Seq(
StructField("sum(year-month)", YearMonthIntervalType()),
StructField("sum(year)", YearMonthIntervalType(YEAR)),
StructField("sum(month)", YearMonthIntervalType(MONTH)),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("sum(day-time)", DayTimeIntervalType()))))
StructField("sum(day-second)", DayTimeIntervalType()),
StructField("sum(day-minute)", DayTimeIntervalType(DAY, MINUTE)),
StructField("sum(day-hour)", DayTimeIntervalType(DAY, HOUR)),
StructField("sum(day)", DayTimeIntervalType(DAY)),
StructField("sum(hour-second)", DayTimeIntervalType(HOUR, SECOND)),
StructField("sum(hour-minute)", DayTimeIntervalType(HOUR, MINUTE)),
StructField("sum(hour)", DayTimeIntervalType(HOUR)),
StructField("sum(minute-second)", DayTimeIntervalType(MINUTE, SECOND)),
StructField("sum(minute)", DayTimeIntervalType(MINUTE)),
StructField("sum(second)", DayTimeIntervalType(SECOND)))))

val sumDF2 =
df.groupBy($"class").agg(sum($"year-month"), sum($"year"), sum($"month"), sum($"day-time"))
intervalData.groupBy($"class").agg(
sum($"year-month"),
sum($"year"),
sum($"month"),
sum($"day-second"),
sum($"day-minute"),
sum($"day-hour"),
sum($"day"),
sum($"hour-second"),
sum($"hour-minute"),
sum($"hour"),
sum($"minute-second"),
sum($"minute"),
sum($"second"))
checkAnswer(sumDF2,
Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.of(1, 6, 0), Period.ofYears(18), Period.ofMonths(2), Duration.ofDays(-11)) ::
Row(1,
Period.ofMonths(10),
Period.ofYears(8),
Period.ofMonths(10),
Duration.ofDays(7).plusHours(13).plusMinutes(3).plusSeconds(18),
Duration.ofDays(5).plusHours(21).plusMinutes(12),
Duration.ofDays(1).plusHours(8),
Duration.ofDays(10),
Duration.ofHours(20).plusMinutes(11).plusSeconds(33),
Duration.ofHours(3).plusMinutes(18),
Duration.ofHours(13),
Duration.ofMinutes(2).plusSeconds(59),
Duration.ofMinutes(38),
Duration.ofSeconds(5)) ::
Row(2,
Period.ofMonths(1),
Period.ofYears(1),
Period.ofMonths(1),
Duration.ofSeconds(1),
Duration.ofMinutes(1),
Duration.ofHours(1),
Duration.ofDays(1),
Duration.ofSeconds(1),
Duration.ofMinutes(1),
Duration.ofHours(1),
Duration.ofSeconds(1),
Duration.ofMinutes(1),
Duration.ofSeconds(1)) ::
Row(3,
Period.of(1, 6, 0),
Period.ofYears(19),
Period.ofMonths(2),
Duration.ofDays(2).plusHours(10).plusMinutes(25).plusSeconds(45),
Duration.ofDays(17).plusHours(11).plusMinutes(14),
Duration.ofDays(-9).plusHours(-16),
Duration.ofDays(-10),
Duration.ofHours(15).plusMinutes(50).plusSeconds(59),
Duration.ofHours(-2).plusMinutes(-36),
Duration.ofHours(-2),
Duration.ofMinutes(15).plusSeconds(3),
Duration.ofMinutes(13),
Duration.ofSeconds(14)) ::
Nil)
assert(find(sumDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(sumDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
StructField("sum(year-month)", YearMonthIntervalType()),
StructField("sum(year)", YearMonthIntervalType(YEAR)),
StructField("sum(month)", YearMonthIntervalType(MONTH)),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("sum(day-time)", DayTimeIntervalType()))))
StructField("sum(day-second)", DayTimeIntervalType()),
StructField("sum(day-minute)", DayTimeIntervalType(DAY, MINUTE)),
StructField("sum(day-hour)", DayTimeIntervalType(DAY, HOUR)),
StructField("sum(day)", DayTimeIntervalType(DAY)),
StructField("sum(hour-second)", DayTimeIntervalType(HOUR, SECOND)),
StructField("sum(hour-minute)", DayTimeIntervalType(HOUR, MINUTE)),
StructField("sum(hour)", DayTimeIntervalType(HOUR)),
StructField("sum(minute-second)", DayTimeIntervalType(MINUTE, SECOND)),
StructField("sum(minute)", DayTimeIntervalType(MINUTE)),
StructField("sum(second)", DayTimeIntervalType(SECOND)))))

val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day")
val error = intercept[SparkException] {
checkAnswer(df2.select(sum($"year-month")), Nil)
}
assert(error.toString contains "java.lang.ArithmeticException: integer overflow")

val error2 = intercept[SparkException] {
checkAnswer(df2.select(sum($"day-time")), Nil)
checkAnswer(df2.select(sum($"day")), Nil)
}
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")
}

test("SPARK-34837: Support ANSI SQL intervals by the aggregate function `avg`") {
val df = Seq(
(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)),
(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)),
(2, null, null, null, null),
(3, Period.ofMonths(-3), Period.ofYears(-12), Period.ofMonths(-3), Duration.ofDays(-6)),
(3, Period.ofMonths(21), Period.ofYears(30), Period.ofMonths(5), Duration.ofDays(-5)),
(3, null, Period.ofYears(1), null, null))
.toDF("class", "year-month", "year", "month", "day-time")
.select(
$"class",
$"year-month",
$"year" cast YearMonthIntervalType(YEAR) as "year",
$"month" cast YearMonthIntervalType(MONTH) as "month",
$"day-time")

val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day-time")

val avgDF = df.select(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
val avgDF = intervalData.select(
avg($"year-month"),
avg($"year"),
avg($"month"),
avg($"day-second"),
avg($"day-minute"),
avg($"day-hour"),
avg($"day"),
avg($"hour-second"),
avg($"hour-minute"),
avg($"hour"),
avg($"minute-second"),
avg($"minute"),
avg($"second"))
checkAnswer(avgDF,
Row(Period.ofMonths(7), Period.of(5, 7, 0), Period.ofMonths(3), Duration.ofDays(0)))
Row(Period.ofMonths(7),
Period.of(5, 7, 0),
Period.ofMonths(3),
Duration.ofDays(2).plusHours(11).plusMinutes(52).plusSeconds(16),
Duration.ofDays(4).plusHours(16).plusMinutes(5).plusSeconds(24),
Duration.ofDays(-1).plusHours(-15).plusMinutes(-48),
Duration.ofHours(4).plusMinutes(48),
Duration.ofHours(9).plusSeconds(38).plusMillis(250),
Duration.ofMinutes(8).plusSeconds(36),
Duration.ofHours(2).plusMinutes(24),
Duration.ofMinutes(4).plusSeconds(30).plusMillis(750),
Duration.ofMinutes(10).plusSeconds(24),
Duration.ofSeconds(5)))
assert(find(avgDF.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF.schema == StructType(Seq(
StructField("avg(year-month)", YearMonthIntervalType()),
StructField("avg(year)", YearMonthIntervalType()),
StructField("avg(month)", YearMonthIntervalType()),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("avg(day-time)", DayTimeIntervalType()))))
StructField("avg(day-second)", DayTimeIntervalType()),
StructField("avg(day-minute)", DayTimeIntervalType()),
StructField("avg(day-hour)", DayTimeIntervalType()),
StructField("avg(day)", DayTimeIntervalType()),
StructField("avg(hour-second)", DayTimeIntervalType()),
StructField("avg(hour-minute)", DayTimeIntervalType()),
StructField("avg(hour)", DayTimeIntervalType()),
StructField("avg(minute-second)", DayTimeIntervalType()),
StructField("avg(minute)", DayTimeIntervalType()),
StructField("avg(second)", DayTimeIntervalType()))))

val avgDF2 =
df.groupBy($"class").agg(avg($"year-month"), avg($"year"), avg($"month"), avg($"day-time"))
intervalData.groupBy($"class").agg(
avg($"year-month"),
avg($"year"),
avg($"month"),
avg($"day-second"),
avg($"day-minute"),
avg($"day-hour"),
avg($"day"),
avg($"hour-second"),
avg($"hour-minute"),
avg($"hour"),
avg($"minute-second"),
avg($"minute"),
avg($"second"))
checkAnswer(avgDF2,
Row(1, Period.ofMonths(10), Period.ofYears(8), Period.ofMonths(10), Duration.ofDays(10)) ::
Row(2, Period.ofMonths(1), Period.ofYears(1), Period.ofMonths(1), Duration.ofDays(1)) ::
Row(3, Period.ofMonths(9), Period.of(6, 4, 0), Period.ofMonths(1),
Duration.ofDays(-5).plusHours(-12)) :: Nil)
Row(1,
Period.ofMonths(10),
Period.ofYears(8),
Period.ofMonths(10),
Duration.ofDays(7).plusHours(13).plusMinutes(3).plusSeconds(18),
Duration.ofDays(5).plusHours(21).plusMinutes(12),
Duration.ofDays(1).plusHours(8),
Duration.ofDays(10),
Duration.ofHours(20).plusMinutes(11).plusSeconds(33),
Duration.ofHours(3).plusMinutes(18),
Duration.ofHours(13),
Duration.ofMinutes(2).plusSeconds(59),
Duration.ofMinutes(38),
Duration.ofSeconds(5)) ::
Row(2,
Period.ofMonths(1),
Period.ofYears(1),
Period.ofMonths(1),
Duration.ofSeconds(1),
Duration.ofMinutes(1),
Duration.ofHours(1),
Duration.ofDays(1),
Duration.ofSeconds(1),
Duration.ofMinutes(1),
Duration.ofHours(1),
Duration.ofSeconds(1),
Duration.ofMinutes(1),
Duration.ofSeconds(1)) ::
Row(3,
Period.ofMonths(9),
Period.of(6, 4, 0),
Period.ofMonths(1),
Duration.ofDays(1).plusHours(5).plusMinutes(12).plusSeconds(52).plusMillis(500),
Duration.ofDays(5).plusHours(19).plusMinutes(44).plusSeconds(40),
Duration.ofDays(-3).plusHours(-5).plusMinutes(-20),
Duration.ofDays(-3).plusHours(-8),
Duration.ofHours(7).plusMinutes(55).plusSeconds(29).plusMillis(500),
Duration.ofMinutes(-52),
Duration.ofMinutes(-40),
Duration.ofMinutes(7).plusSeconds(31).plusMillis(500),
Duration.ofMinutes(4).plusSeconds(20),
Duration.ofSeconds(7)) :: Nil)
assert(find(avgDF2.queryExecution.executedPlan)(_.isInstanceOf[HashAggregateExec]).isDefined)
assert(avgDF2.schema == StructType(Seq(StructField("class", IntegerType, false),
assert(avgDF2.schema == StructType(Seq(
StructField("class", IntegerType, false),
StructField("avg(year-month)", YearMonthIntervalType()),
StructField("avg(year)", YearMonthIntervalType()),
StructField("avg(month)", YearMonthIntervalType()),
// TODO(SPARK-35729): Check all day-time interval types in aggregate expressions
StructField("avg(day-time)", DayTimeIntervalType()))))
StructField("avg(day-second)", DayTimeIntervalType()),
StructField("avg(day-minute)", DayTimeIntervalType()),
StructField("avg(day-hour)", DayTimeIntervalType()),
StructField("avg(day)", DayTimeIntervalType()),
StructField("avg(hour-second)", DayTimeIntervalType()),
StructField("avg(hour-minute)", DayTimeIntervalType()),
StructField("avg(hour)", DayTimeIntervalType()),
StructField("avg(minute-second)", DayTimeIntervalType()),
StructField("avg(minute)", DayTimeIntervalType()),
StructField("avg(second)", DayTimeIntervalType()))))

val df2 = Seq((Period.ofMonths(Int.MaxValue), Duration.ofDays(106751991)),
(Period.ofMonths(10), Duration.ofDays(10)))
.toDF("year-month", "day")
val error = intercept[SparkException] {
checkAnswer(df2.select(avg($"year-month")), Nil)
}
assert(error.toString contains "java.lang.ArithmeticException: integer overflow")

val error2 = intercept[SparkException] {
checkAnswer(df2.select(avg($"day-time")), Nil)
checkAnswer(df2.select(avg($"day")), Nil)
}
assert(error2.toString contains "java.lang.ArithmeticException: long overflow")

val df3 = df.filter($"class" > 4)
val avgDF3 = df3.select(avg($"year-month"), avg($"day-time"))
val df3 = intervalData.filter($"class" > 4)
val avgDF3 = df3.select(avg($"year-month"), avg($"day"))
checkAnswer(avgDF3, Row(null, null) :: Nil)

val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day-time"))
val avgDF4 = df3.groupBy($"class").agg(avg($"year-month"), avg($"day"))
checkAnswer(avgDF4, Nil)
}

