[SPARK-22362][SQL] Add unit test for Window Aggregate Functions
## What changes were proposed in this pull request?

This PR improves the test coverage of window functions, focusing on the missing tests for window aggregate functions. No new UDAF test is added, since UDAFs are already covered by existing tests.
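For context, the functions under test are aggregates evaluated over a window frame rather than over grouped rows. Below is a minimal sketch of the pattern the new tests exercise; the session setup, object name, and sample data are illustrative and not part of this patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object WindowAggregateSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
  import spark.implicits._

  // Unlike GROUP BY, a window aggregate keeps every input row and attaches
  // the aggregate of the row's whole partition (frame) to it.
  val df = Seq(("a", 1.0), ("a", 2.0), ("b", 3.0)).toDF("cate", "val")
  val w = Window.partitionBy("cate").orderBy("val")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

  df.select($"cate", $"val", var_pop("val").over(w).as("var_pop")).show()
  spark.stop()
}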

## How was this patch tested?

Only new tests were added; the automated test suite was executed.

Author: “attilapiros” <piros.attila.zsolt@gmail.com>
Author: Attila Zsolt Piros <2017933+attilapiros@users.noreply.github.com>

Closes #20046 from attilapiros/SPARK-22362.
attilapiros authored and hvanhovell committed Apr 19, 2018
1 parent a471880 commit 9ea8d3d
Showing 3 changed files with 294 additions and 12 deletions.
10 changes: 9 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/window.sql
@@ -76,7 +76,15 @@ ntile(2) OVER w AS ntile,
row_number() OVER w AS row_number,
var_pop(val) OVER w AS var_pop,
var_samp(val) OVER w AS var_samp,
approx_count_distinct(val) OVER w AS approx_count_distinct
approx_count_distinct(val) OVER w AS approx_count_distinct,
covar_pop(val, val_long) OVER w AS covar_pop,
corr(val, val_long) OVER w AS corr,
stddev_samp(val) OVER w AS stddev_samp,
stddev_pop(val) OVER w AS stddev_pop,
collect_list(val) OVER w AS collect_list,
collect_set(val) OVER w AS collect_set,
skewness(val_double) OVER w AS skewness,
kurtosis(val_double) OVER w AS kurtosis
FROM testData
WINDOW w AS (PARTITION BY cate ORDER BY val)
ORDER BY cate, val;
30 changes: 19 additions & 11 deletions sql/core/src/test/resources/sql-tests/results/window.sql.out
@@ -273,22 +273,30 @@ ntile(2) OVER w AS ntile,
row_number() OVER w AS row_number,
var_pop(val) OVER w AS var_pop,
var_samp(val) OVER w AS var_samp,
approx_count_distinct(val) OVER w AS approx_count_distinct
approx_count_distinct(val) OVER w AS approx_count_distinct,
covar_pop(val, val_long) OVER w AS covar_pop,
corr(val, val_long) OVER w AS corr,
stddev_samp(val) OVER w AS stddev_samp,
stddev_pop(val) OVER w AS stddev_pop,
collect_list(val) OVER w AS collect_list,
collect_set(val) OVER w AS collect_set,
skewness(val_double) OVER w AS skewness,
kurtosis(val_double) OVER w AS kurtosis
FROM testData
WINDOW w AS (PARTITION BY cate ORDER BY val)
ORDER BY cate, val
-- !query 17 schema
struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint>
struct<val:int,cate:string,max:int,min:int,min:int,count:bigint,sum:bigint,avg:double,stddev:double,first_value:int,first_value_ignore_null:int,first_value_contain_null:int,last_value:int,last_value_ignore_null:int,last_value_contain_null:int,rank:int,dense_rank:int,cume_dist:double,percent_rank:double,ntile:int,row_number:int,var_pop:double,var_samp:double,approx_count_distinct:bigint,covar_pop:double,corr:double,stddev_samp:double,stddev_pop:double,collect_list:array<int>,collect_set:array<int>,skewness:double,kurtosis:double>
-- !query 17 output
NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0
3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1
NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0
1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1
1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1
2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2
1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1
2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2
3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3
NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NULL NULL
3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 0.0 NaN NaN 0.0 [3] [3] NaN NaN
NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 NULL NULL NULL NULL [] [] NaN NaN
1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 1 2 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5
1 a 1 1 1 2 2 1.0 0.0 NULL 1 NULL 1 1 1 2 2 0.75 0.3333333333333333 2 3 0.0 0.0 1 0.0 NULL 0.0 0.0 [1,1] [1] 0.7071067811865476 -1.5
2 a 2 1 1 3 4 1.3333333333333333 0.5773502691896258 NULL 1 NULL 2 2 2 4 3 1.0 1.0 2 4 0.22222222222222224 0.33333333333333337 2 4.772185885555555E8 1.0 0.5773502691896258 0.4714045207910317 [1,1,2] [1,2] 1.1539890888012805 -0.6672217220327235
1 b 1 1 1 1 1 1.0 NaN 1 1 1 1 1 1 1 1 0.3333333333333333 0.0 1 1 0.0 NaN 1 NULL NULL NaN 0.0 [1] [1] NaN NaN
2 b 2 1 1 2 3 1.5 0.7071067811865476 1 1 1 2 2 2 2 2 0.6666666666666666 0.5 1 2 0.25 0.5 2 0.0 NaN 0.7071067811865476 0.5 [1,2] [1,2] 0.0 -2.0000000000000013
3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 5.3687091175E8 1.0 1.0 0.816496580927726 [1,2,3] [1,2,3] 0.7057890433107311 -1.4999999999999984


-- !query 18
266 changes: 266 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql

import java.sql.{Date, Timestamp}

import scala.collection.mutable

import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
@@ -86,6 +88,236 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
    assert(e.message.contains("requires window to be ordered"))
  }

test("corr, covar_pop, stddev_pop functions in specific window") {
val df = Seq(
("a", "p1", 10.0, 20.0),
("b", "p1", 20.0, 10.0),
("c", "p2", 20.0, 20.0),
("d", "p2", 20.0, 20.0),
("e", "p3", 0.0, 0.0),
("f", "p3", 6.0, 12.0),
("g", "p3", 6.0, 12.0),
("h", "p3", 8.0, 16.0),
("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
checkAnswer(
df.select(
$"key",
corr("value1", "value2").over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
covar_pop("value1", "value2")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
var_pop("value1")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
stddev_pop("value1")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
var_pop("value2")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
stddev_pop("value2")
.over(Window.partitionBy("partitionId")
.orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),

// As stddev_pop(expr) = sqrt(var_pop(expr))
// the "stddev_pop" column can be calculated from the "var_pop" column.
//
// As corr(expr1, expr2) = covar_pop(expr1, expr2) / (stddev_pop(expr1) * stddev_pop(expr2))
// the "corr" column can be calculated from the "covar_pop" and the two "stddev_pop" columns.
      Seq(
        Row("a", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
        Row("b", -1.0, -25.0, 25.0, 5.0, 25.0, 5.0),
        Row("c", null, 0.0, 0.0, 0.0, 0.0, 0.0),
        Row("d", null, 0.0, 0.0, 0.0, 0.0, 0.0),
        Row("e", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
        Row("f", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
        Row("g", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
        Row("h", 1.0, 18.0, 9.0, 3.0, 36.0, 6.0),
        Row("i", Double.NaN, 0.0, 0.0, 0.0, 0.0, 0.0)))
  }

test("covar_samp, var_samp (variance), stddev_samp (stddev) functions in specific window") {
val df = Seq(
("a", "p1", 10.0, 20.0),
("b", "p1", 20.0, 10.0),
("c", "p2", 20.0, 20.0),
("d", "p2", 20.0, 20.0),
("e", "p3", 0.0, 0.0),
("f", "p3", 6.0, 12.0),
("g", "p3", 6.0, 12.0),
("h", "p3", 8.0, 16.0),
("i", "p4", 5.0, 5.0)).toDF("key", "partitionId", "value1", "value2")
    checkAnswer(
      df.select(
        $"key",
        covar_samp("value1", "value2").over(Window.partitionBy("partitionId")
          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
        var_samp("value1").over(Window.partitionBy("partitionId")
          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
        variance("value1").over(Window.partitionBy("partitionId")
          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
        stddev_samp("value1").over(Window.partitionBy("partitionId")
          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
        stddev("value1").over(Window.partitionBy("partitionId")
          .orderBy("key").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))
      ),
      Seq(
        Row("a", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
        Row("b", -50.0, 50.0, 50.0, 7.0710678118654755, 7.0710678118654755),
        Row("c", 0.0, 0.0, 0.0, 0.0, 0.0),
        Row("d", 0.0, 0.0, 0.0, 0.0, 0.0),
        Row("e", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
        Row("f", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
        Row("g", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
        Row("h", 24.0, 12.0, 12.0, 3.4641016151377544, 3.4641016151377544),
        Row("i", Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN)))
  }

test("collect_list in ascending ordered window") {
val df = Seq(
("a", "p1", "1"),
("b", "p1", "2"),
("c", "p1", "2"),
("d", "p1", null),
("e", "p1", "3"),
("f", "p2", "10"),
("g", "p2", "11"),
("h", "p3", "20"),
("i", "p4", null)).toDF("key", "partition", "value")
    checkAnswer(
      df.select(
        $"key",
        sort_array(
          collect_list("value").over(Window.partitionBy($"partition").orderBy($"value")
            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
      Seq(
        Row("a", Array("1", "2", "2", "3")),
        Row("b", Array("1", "2", "2", "3")),
        Row("c", Array("1", "2", "2", "3")),
        Row("d", Array("1", "2", "2", "3")),
        Row("e", Array("1", "2", "2", "3")),
        Row("f", Array("10", "11")),
        Row("g", Array("10", "11")),
        Row("h", Array("20")),
        Row("i", Array())))
  }

test("collect_list in descending ordered window") {
val df = Seq(
("a", "p1", "1"),
("b", "p1", "2"),
("c", "p1", "2"),
("d", "p1", null),
("e", "p1", "3"),
("f", "p2", "10"),
("g", "p2", "11"),
("h", "p3", "20"),
("i", "p4", null)).toDF("key", "partition", "value")
checkAnswer(
df.select(
$"key",
sort_array(
collect_list("value").over(Window.partitionBy($"partition").orderBy($"value".desc)
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
Seq(
Row("a", Array("1", "2", "2", "3")),
Row("b", Array("1", "2", "2", "3")),
Row("c", Array("1", "2", "2", "3")),
Row("d", Array("1", "2", "2", "3")),
Row("e", Array("1", "2", "2", "3")),
Row("f", Array("10", "11")),
Row("g", Array("10", "11")),
Row("h", Array("20")),
Row("i", Array())))
}

test("collect_set in window") {
val df = Seq(
("a", "p1", "1"),
("b", "p1", "2"),
("c", "p1", "2"),
("d", "p1", "3"),
("e", "p1", "3"),
("f", "p2", "10"),
("g", "p2", "11"),
("h", "p3", "20")).toDF("key", "partition", "value")
    checkAnswer(
      df.select(
        $"key",
        sort_array(
          collect_set("value").over(Window.partitionBy($"partition").orderBy($"value")
            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))),
      Seq(
        Row("a", Array("1", "2", "3")),
        Row("b", Array("1", "2", "3")),
        Row("c", Array("1", "2", "3")),
        Row("d", Array("1", "2", "3")),
        Row("e", Array("1", "2", "3")),
        Row("f", Array("10", "11")),
        Row("g", Array("10", "11")),
        Row("h", Array("20"))))
  }

test("skewness and kurtosis functions in window") {
val df = Seq(
("a", "p1", 1.0),
("b", "p1", 1.0),
("c", "p1", 2.0),
("d", "p1", 2.0),
("e", "p1", 3.0),
("f", "p1", 3.0),
("g", "p1", 3.0),
("h", "p2", 1.0),
("i", "p2", 2.0),
("j", "p2", 5.0)).toDF("key", "partition", "value")
checkAnswer(
df.select(
$"key",
skewness("value").over(Window.partitionBy("partition").orderBy($"key")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)),
kurtosis("value").over(Window.partitionBy("partition").orderBy($"key")
.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
// results are checked by scipy.stats.skew() and scipy.stats.kurtosis()
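      // (with their defaults, i.e. biased/population skewness and Fisher's
      // excess kurtosis, which match the definitions Spark implements)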
      Seq(
        Row("a", -0.27238010581457267, -1.506920415224914),
        Row("b", -0.27238010581457267, -1.506920415224914),
        Row("c", -0.27238010581457267, -1.506920415224914),
        Row("d", -0.27238010581457267, -1.506920415224914),
        Row("e", -0.27238010581457267, -1.506920415224914),
        Row("f", -0.27238010581457267, -1.506920415224914),
        Row("g", -0.27238010581457267, -1.506920415224914),
        Row("h", 0.5280049792181881, -1.5000000000000013),
        Row("i", 0.5280049792181881, -1.5000000000000013),
        Row("j", 0.5280049792181881, -1.5000000000000013)))
  }

test("aggregation function on invalid column") {
val df = Seq((1, "1")).toDF("key", "value")
val e = intercept[AnalysisException](
df.select($"key", count("invalid").over()))
assert(e.message.contains("cannot resolve '`invalid`' given input columns: [key, value]"))
}

test("numerical aggregate functions on string column") {
val df = Seq((1, "a", "b")).toDF("key", "value1", "value2")
    checkAnswer(
      df.select($"key",
        var_pop("value1").over(),
        variance("value1").over(),
        stddev_pop("value1").over(),
        stddev("value1").over(),
        sum("value1").over(),
        mean("value1").over(),
        avg("value1").over(),
        corr("value1", "value2").over(),
        covar_pop("value1", "value2").over(),
        covar_samp("value1", "value2").over(),
        skewness("value1").over(),
        kurtosis("value1").over()),
      Seq(Row(1, null, null, null, null, null, null, null, null, null, null, null, null)))
  }

test("statistical functions") {
val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)).
toDF("key", "value")
@@ -232,6 +464,40 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row("b", 2, null, null, null, null, null, null)))
}

test("last/first on descending ordered window") {
val nullStr: String = null
val df = Seq(
("a", 0, nullStr),
("a", 1, "x"),
("a", 2, "y"),
("a", 3, "z"),
("a", 4, "v"),
("b", 1, "k"),
("b", 2, "l"),
("b", 3, nullStr)).
toDF("key", "order", "value")
val window = Window.partitionBy($"key").orderBy($"order".desc)
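    // With an ORDER BY but no explicit frame, the window defaults to
    // RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so under the
    // descending sort first() always sees the row with the largest `order`,
    // while last() reflects the current row.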
    checkAnswer(
      df.select(
        $"key",
        $"order",
        first($"value").over(window),
        first($"value", ignoreNulls = false).over(window),
        first($"value", ignoreNulls = true).over(window),
        last($"value").over(window),
        last($"value", ignoreNulls = false).over(window),
        last($"value", ignoreNulls = true).over(window)),
      Seq(
        Row("a", 0, "v", "v", "v", null, null, "x"),
        Row("a", 1, "v", "v", "v", "x", "x", "x"),
        Row("a", 2, "v", "v", "v", "y", "y", "y"),
        Row("a", 3, "v", "v", "v", "z", "z", "z"),
        Row("a", 4, "v", "v", "v", "v", "v", "v"),
        Row("b", 1, null, null, "l", "k", "k", "k"),
        Row("b", 2, null, null, "l", "l", "l", "l"),
        Row("b", 3, null, null, null, null, null, null)))
  }

test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") {
val src = Seq((0, 3, 5)).toDF("a", "b", "c")
.withColumn("Data", struct("a", "b"))
