Skip to content

Commit

Permalink
[SPARK-23907] Removes regr_ functions in functions.scala
Browse files Browse the repository at this point in the history
  • Loading branch information
rxin committed May 11, 2018
1 parent 92f6f52 commit ce2c305
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 239 deletions.
171 changes: 0 additions & 171 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -811,177 +811,6 @@ object functions {
*/
def var_pop(columnName: String): Column = var_pop(Column(columnName))

/**
 * Aggregate function: counts the (y, x) pairs in which neither value is null.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_count(y: Column, x: Column): Column =
  withAggregateFunction(RegrCount(y.expr, x.expr))

/**
 * Aggregate function: counts the (y, x) pairs in which neither value is null.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_count`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_count(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_count(yColumn, xColumn)
}

/**
 * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Pairs in which either value is NULL
 * are ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_sxx(y: Column, x: Column): Column =
  withAggregateFunction(RegrSXX(y.expr, x.expr))

/**
 * Aggregate function: returns SUM(x*x)-SUM(x)*SUM(x)/N. Pairs in which either value is NULL
 * are ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_sxx`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_sxx(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_sxx(yColumn, xColumn)
}

/**
 * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Pairs in which either value is NULL
 * are ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_syy(y: Column, x: Column): Column =
  withAggregateFunction(RegrSYY(y.expr, x.expr))

/**
 * Aggregate function: returns SUM(y*y)-SUM(y)*SUM(y)/N. Pairs in which either value is NULL
 * are ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_syy`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_syy(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_syy(yColumn, xColumn)
}

/**
 * Aggregate function: returns the average of y. Pairs in which either value is NULL are
 * ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_avgy(y: Column, x: Column): Column =
  withAggregateFunction(RegrAvgY(y.expr, x.expr))

/**
 * Aggregate function: returns the average of y. Pairs in which either value is NULL are
 * ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_avgy`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_avgy(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_avgy(yColumn, xColumn)
}

/**
 * Aggregate function: returns the average of x. Pairs in which either value is NULL are
 * ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_avgx(y: Column, x: Column): Column =
  withAggregateFunction(RegrAvgX(y.expr, x.expr))

/**
 * Aggregate function: returns the average of x. Pairs in which either value is NULL are
 * ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_avgx`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_avgx(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_avgx(yColumn, xColumn)
}

/**
 * Aggregate function: returns the covariance of y and x multiplied by the number of items in
 * the dataset. Pairs in which either value is NULL are ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_sxy(y: Column, x: Column): Column =
  withAggregateFunction(RegrSXY(y.expr, x.expr))

/**
 * Aggregate function: returns the covariance of y and x multiplied by the number of items in
 * the dataset. Pairs in which either value is NULL are ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_sxy`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_sxy(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_sxy(yColumn, xColumn)
}

/**
 * Aggregate function: returns the slope of the linear regression line. Pairs in which either
 * value is NULL are ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_slope(y: Column, x: Column): Column =
  withAggregateFunction(RegrSlope(y.expr, x.expr))

/**
 * Aggregate function: returns the slope of the linear regression line. Pairs in which either
 * value is NULL are ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_slope`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_slope(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_slope(yColumn, xColumn)
}

/**
 * Aggregate function: returns the coefficient of determination (also known as R-squared or
 * goodness of fit) for the regression line. Pairs in which either value is NULL are ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_r2(y: Column, x: Column): Column =
  withAggregateFunction(RegrR2(y.expr, x.expr))

/**
 * Aggregate function: returns the coefficient of determination (also known as R-squared or
 * goodness of fit) for the regression line. Pairs in which either value is NULL are ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_r2`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_r2(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_r2(yColumn, xColumn)
}

/**
 * Aggregate function: returns the y-intercept of the linear regression line. Pairs in which
 * either value is NULL are ignored.
 *
 * @param y column supplying the y value of each pair
 * @param x column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_intercept(y: Column, x: Column): Column =
  withAggregateFunction(RegrIntercept(y.expr, x.expr))

/**
 * Aggregate function: returns the y-intercept of the linear regression line. Pairs in which
 * either value is NULL are ignored.
 *
 * Name-based overload: both names are wrapped in a `Column` and the call is forwarded to the
 * `Column`-based `regr_intercept`.
 *
 * @param y name of the column supplying the y value of each pair
 * @param x name of the column supplying the x value of each pair
 * @group agg_funcs
 * @since 2.4.0
 */
def regr_intercept(y: String, x: String): Column = {
  val yColumn = Column(y)
  val xColumn = Column(x)
  regr_intercept(yColumn, xColumn)
}



//////////////////////////////////////////////////////////////////////////////////////////////
// Window functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -687,72 +687,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}
}
}

// Exercises every regr_* aggregate through its String-name overload. In each call the first
// argument is y and the second is x.
// NOTE(review): testData2, testData3, checkAggregatesWithTol and absTol are defined outside this
// view; the expected values for testData2/testData3 presumably follow from their contents --
// confirm against the fixture definitions.
test("SPARK-23907: regression functions") {
// Zero-row dataset: regr_count is expected to yield 0, every other aggregate null.
val emptyTableData = Seq.empty[(Double, Double)].toDF("a", "b")
// Four fully populated (a, b) pairs.
val correlatedData = Seq[(Double, Double)]((2, 3), (3, 4), (7.5, 8.2), (10.3, 12))
.toDF("a", "b")
// Same pairs as above except that the second row has a null b; boxed java.lang.Double is used
// so that the null can be represented.
val correlatedDataWithNull = Seq[(java.lang.Double, java.lang.Double)](
(2.0, 3.0), (3.0, null), (7.5, 8.2), (10.3, 12.0)).toDF("a", "b")
// regr_count: exact integer comparison via checkAnswer.
checkAnswer(testData2.groupBy().agg(regr_count("a", "b")), Seq(Row(6)))
checkAnswer(testData3.groupBy().agg(regr_count("a", "b")), Seq(Row(1)))
checkAnswer(emptyTableData.groupBy().agg(regr_count("a", "b")), Seq(Row(0)))

// regr_sxx / regr_syy: the argument order is swapped between the two groups so that both
// operate on column b (as x for sxx, as y for syy).
checkAggregatesWithTol(testData2.groupBy().agg(regr_sxx("a", "b")), Row(1.5), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_sxx("a", "b")), Row(0.0), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxx("a", "b")), Row(null), absTol)
checkAggregatesWithTol(testData2.groupBy().agg(regr_syy("b", "a")), Row(1.5), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_syy("b", "a")), Row(0.0), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_syy("b", "a")), Row(null), absTol)

// regr_avgx / regr_avgy: again both groups target column b (x for avgx, y for avgy).
checkAggregatesWithTol(testData2.groupBy().agg(regr_avgx("a", "b")), Row(1.5), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_avgx("a", "b")), Row(2.0), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgx("a", "b")), Row(null), absTol)
checkAggregatesWithTol(testData2.groupBy().agg(regr_avgy("b", "a")), Row(1.5), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_avgy("b", "a")), Row(2.0), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_avgy("b", "a")), Row(null), absTol)

// regr_sxy
checkAggregatesWithTol(testData2.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_sxy("a", "b")), Row(0.0), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_sxy("a", "b")), Row(null), absTol)

// regr_slope: null is expected for testData3 as well as for the empty dataset.
checkAggregatesWithTol(testData2.groupBy().agg(regr_slope("a", "b")), Row(0.0), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_slope("a", "b")), Row(null), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_slope("a", "b")), Row(null), absTol)

// regr_r2: null is expected for testData3 as well as for the empty dataset.
checkAggregatesWithTol(testData2.groupBy().agg(regr_r2("a", "b")), Row(0.0), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_r2("a", "b")), Row(null), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_r2("a", "b")), Row(null), absTol)

// regr_intercept: null is expected for testData3 as well as for the empty dataset.
checkAggregatesWithTol(testData2.groupBy().agg(regr_intercept("a", "b")), Row(2.0), absTol)
checkAggregatesWithTol(testData3.groupBy().agg(regr_intercept("a", "b")), Row(null), absTol)
checkAggregatesWithTol(emptyTableData.groupBy().agg(regr_intercept("a", "b")),
Row(null), absTol)


// All nine aggregates evaluated in a single agg() over the fully populated pairs; the expected
// Row lists the results in the same order as the aggregate expressions.
checkAggregatesWithTol(correlatedData.groupBy().agg(
regr_count("a", "b"),
regr_avgx("a", "b"),
regr_avgy("a", "b"),
regr_sxx("a", "b"),
regr_syy("a", "b"),
regr_sxy("a", "b"),
regr_slope("a", "b"),
regr_r2("a", "b"),
regr_intercept("a", "b")),
Row(4, 6.8, 5.7, 51.28, 45.38, 48.06, 0.937207488, 0.992556013, -0.67301092),
absTol)
// Same nine aggregates over the dataset containing one null b: the expected count is 3 rather
// than 4, i.e. the pair with the null is expected to be left out of every aggregate.
checkAggregatesWithTol(correlatedDataWithNull.groupBy().agg(
regr_count("a", "b"),
regr_avgx("a", "b"),
regr_avgy("a", "b"),
regr_sxx("a", "b"),
regr_syy("a", "b"),
regr_sxy("a", "b"),
regr_slope("a", "b"),
regr_r2("a", "b"),
regr_intercept("a", "b")),
Row(3, 7.73333333, 6.6, 40.82666666, 35.66, 37.98, 0.93027433, 0.99079694, -0.59412149),
absTol)
}
}

0 comments on commit ce2c305

Please sign in to comment.