Skip to content

Commit

Permalink
Add document. Return NaN when count is zero.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Oct 27, 2015
1 parent e1fb438 commit 02562f3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,16 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
override val evaluateExpression = Cast(currentSum, resultType)
}

/**
* Compute Pearson correlation between two expressions.
* When applied on empty data (i.e., count is zero), it returns NaN.
*
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*
* @param left one of the expressions to compute correlation with.
* @param right another expression to compute correlation with.
*/
case class Corr(
left: Expression,
right: Expression,
Expand Down Expand Up @@ -591,6 +601,8 @@ case class Corr(
buffer.setLong(mutableAggBufferOffset + 5, count)
}

// Merge counters from other partitions. Formula can be found at:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = {
val count2 = buffer2.getLong(inputAggBufferOffset + 5)

Expand Down Expand Up @@ -628,10 +640,15 @@ case class Corr(
}

override def eval(buffer: InternalRow): Any = {
val Ck = buffer.getDouble(mutableAggBufferOffset + 2)
val MkX = buffer.getDouble(mutableAggBufferOffset + 3)
val MkY = buffer.getDouble(mutableAggBufferOffset + 4)
Ck / math.sqrt(MkX * MkY)
val count = buffer.getLong(mutableAggBufferOffset + 5)
if (count > 0) {
val Ck = buffer.getDouble(mutableAggBufferOffset + 2)
val MkX = buffer.getDouble(mutableAggBufferOffset + 3)
val MkY = buffer.getDouble(mutableAggBufferOffset + 4)
Ck / math.sqrt(MkX * MkY)
} else {
Double.NaN
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
val df2 = Seq.tabulate(20)(x => (1.0 * x, x * x - 2 * x + 3.5)).toDF("a", "b")
val corr3 = df2.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)

val df3 = Seq.tabulate(0)(i => (1.0 * i, 2.0 * i)).toDF("a", "b")
val corr4 = df3.groupBy().agg(corr("a", "b")).collect()(0).getDouble(0)
assert(corr4.isNaN)
}

test("test Last implemented based on AggregateExpression1") {
Expand Down

0 comments on commit 02562f3

Please sign in to comment.