Add skew and kurtosis operations
jbrooks-stripe committed Apr 9, 2024
1 parent 1072c6d commit ba667d6
Showing 8 changed files with 253 additions and 2 deletions.
@@ -30,6 +30,7 @@ import com.yahoo.sketches.{
}

import java.util
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

class Sum[I: Numeric](inputType: DataType) extends SimpleAggregator[I, I, I] {
@@ -568,3 +569,110 @@ class TopK[T: Ordering: ClassTag](inputType: DataType, k: Int)
extends OrderByLimit[T](inputType, k, Ordering[T].reverse)

class BottomK[T: Ordering: ClassTag](inputType: DataType, k: Int) extends OrderByLimit[T](inputType, k, Ordering[T])

case class MomentsIR(
n: Double,
m1: Double,
m2: Double,
m3: Double,
m4: Double
)

// Uses the Welford/Knuth online method, since the traditional sum-of-squares formula has serious numerical stability problems:
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Higher-order_statistics
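// Sketch of the recurrences implemented below: the update step is the Welford recurrence
// extended to the 3rd and 4th central moments, e.g. for m2 (with n the new count):
//   delta = x - m1;  m1' = m1 + delta / n;  m2' = m2 + delta * (x - m1')
// and merge combines two partitions A and B with the parallel (Chan et al.) form, e.g.:
//   delta = m1_B - m1_A;  m2 = m2_A + m2_B + delta^2 * n_A * n_B / (n_A + n_B)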
trait MomentAggregator extends SimpleAggregator[Double, MomentsIR, Double] {
override def prepare(input: Double): MomentsIR = {
val ir = MomentsIR(
n = 0,
m1 = 0,
m2 = 0,
m3 = 0,
m4 = 0
)

update(ir, input)
}

override def update(ir: MomentsIR, x: Double): MomentsIR = {
val n1 = ir.n
val n = ir.n + 1
val delta = x - ir.m1
val deltaN = delta / n
val deltaN2 = deltaN * deltaN
val term1 = delta * deltaN * n1
val m1 = ir.m1 + deltaN
val m4 = ir.m4 + term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * ir.m2 - 4 * deltaN * ir.m3
val m3 = ir.m3 + term1 * deltaN * (n - 2) - 3 * deltaN * ir.m2
val m2 = ir.m2 + term1

MomentsIR(
n = n,
m1 = m1,
m2 = m2,
m3 = m3,
m4 = m4
)
}

override def outputType: DataType = DoubleType

override def irType: DataType = ListType(DoubleType)

override def merge(a: MomentsIR, b: MomentsIR): MomentsIR = {
val n = a.n + b.n
val delta = b.m1 - a.m1
val delta2 = delta * delta
val delta3 = delta * delta2
val delta4 = delta2 * delta2

val m1 = (a.n * a.m1 + b.n * b.m1) / n
val m2 = a.m2 + b.m2 + delta2 * a.n * b.n / n
val m3 = a.m3 + b.m3 + delta3 * a.n * b.n * (a.n - b.n) / (n * n) +
3.0 * delta * (a.n * b.m2 - b.n * a.m2) / n
val m4 = a.m4 + b.m4 + delta4 * a.n * b.n * (a.n * a.n - a.n * b.n + b.n * b.n) / (n * n * n) +
6.0 * delta2 * (a.n * a.n * b.m2 + b.n * b.n * a.m2) / (n * n) + 4.0 * delta * (a.n * b.m3 - b.n * a.m3) / n

MomentsIR(
n = n,
m1 = m1,
m2 = m2,
m3 = m3,
m4 = m4
)
}

override def finalize(ir: MomentsIR): Double

override def clone(ir: MomentsIR): MomentsIR = {
MomentsIR(
n = ir.n,
m1 = ir.m1,
m2 = ir.m2,
m3 = ir.m3,
m4 = ir.m4
)
}

override def normalize(ir: MomentsIR): util.ArrayList[Double] = {
val values = List(ir.n, ir.m1, ir.m2, ir.m3, ir.m4)
new util.ArrayList[Double](values.asJava)
}

override def denormalize(normalized: Any): MomentsIR = {
val values = normalized.asInstanceOf[util.ArrayList[Double]].asScala
MomentsIR(values(0), values(1), values(2), values(3), values(4))
}

override def isDeletable = false
}

class Skew extends MomentAggregator {
override def finalize(ir: MomentsIR): Double =
if (ir.n < 3 || ir.m2 == 0) Double.NaN else Math.sqrt(ir.n) * ir.m3 / Math.pow(ir.m2, 1.5)
}

class Kurtosis extends MomentAggregator {
override def finalize(ir: MomentsIR): Double =
if (ir.n < 4 || ir.m2 == 0) Double.NaN else ir.n * ir.m4 / (ir.m2 * ir.m2) - 3
}
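A minimal usage sketch of the aggregator lifecycle added above (illustration only, not part of the diff; values and variable names are made up): finalize computes the population skew g1 = sqrt(n) * m3 / m2^1.5 and the excess kurtosis n * m4 / m2^2 - 3 from the accumulated moments, and merge combines partial aggregates from separate partitions exactly, matching a single sequential pass.

    val skew = new Skew
    val left = Seq(2.0, 3.0).foldLeft(skew.prepare(1.0))(skew.update)  // one partition: 1.0, 2.0, 3.0
    val right = Seq(5.0).foldLeft(skew.prepare(4.0))(skew.update)      // another partition: 4.0, 5.0
    val merged = skew.merge(left, right)                               // equivalent to one sequential pass
    skew.finalize(merged)                                              // NaN when n < 3 or m2 == 0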
@@ -325,6 +325,26 @@ object ColumnAggregator {
case _ => mismatchException
}

case Operation.SKEW =>
inputType match {
case IntType => simple(new Skew, toDouble[Int])
case LongType => simple(new Skew, toDouble[Long])
case ShortType => simple(new Skew, toDouble[Short])
case DoubleType => simple(new Skew)
case FloatType => simple(new Skew, toDouble[Float])
case _ => mismatchException
}

case Operation.KURTOSIS =>
inputType match {
case IntType => simple(new Kurtosis, toDouble[Int])
case LongType => simple(new Kurtosis, toDouble[Long])
case ShortType => simple(new Kurtosis, toDouble[Short])
case DoubleType => simple(new Kurtosis)
case FloatType => simple(new Kurtosis, toDouble[Float])
case _ => mismatchException
}

case Operation.MIN =>
inputType match {
case IntType => simple(new Min[Int](inputType))
@@ -0,0 +1,72 @@
package ai.chronon.aggregator.test

import ai.chronon.aggregator.base._
import junit.framework.TestCase
import org.apache.commons.math3.stat.descriptive.moment.{Kurtosis => ApacheKurtosis, Skewness => ApacheSkew}
import org.junit.Assert._

class MomentTest extends TestCase {
def makeAgg(aggregator: MomentAggregator, values: Seq[Double]): (MomentAggregator, MomentsIR) = {
var ir = aggregator.prepare(values.head)

values.tail.foreach(v => {
ir = aggregator.update(ir, v)
})

(aggregator, ir)
}

def expectedSkew(values: Seq[Double]): Double = new ApacheSkew().evaluate(values.toArray)
def expectedKurtosis(values: Seq[Double]): Double = new ApacheKurtosis().evaluate(values.toArray)

def assertUpdate(aggregator: MomentAggregator, values: Seq[Double], expected: Seq[Double] => Double): Unit = {
val (agg, ir) = makeAgg(aggregator, values)
assertEquals(expected(values), agg.finalize(ir), 0.1)
}

def assertMerge(aggregator: MomentAggregator,
v1: Seq[Double],
v2: Seq[Double],
expected: Seq[Double] => Double): Unit = {
val (agg, ir1) = makeAgg(aggregator, v1)
val (_, ir2) = makeAgg(aggregator, v2)

val ir = agg.merge(ir1, ir2)
assertEquals(expected(v1 ++ v2), agg.finalize(ir), 0.1)
}

def testUpdate(): Unit = {
val values = Seq(1.1, 2.2, 3.3, 4.4, 5.5)
assertUpdate(new Skew(), values, expectedSkew)
assertUpdate(new Kurtosis(), values, expectedKurtosis)
}

def testInsufficientSizes(): Unit = {
val values = Seq(1.1, 2.2, 3.3, 4.4)
assertUpdate(new Skew(), values.take(2), _ => Double.NaN)
assertUpdate(new Kurtosis(), values.take(3), _ => Double.NaN)
}

def testNoVariance(): Unit = {
val values = Seq(1.0, 1.0, 1.0, 1.0)
assertUpdate(new Skew(), values, _ => Double.NaN)
assertUpdate(new Kurtosis(), values, _ => Double.NaN)
}

def testMerge(): Unit = {
val values1 = Seq(1.1, 2.2, 3.3)
val values2 = Seq(4.4, 5.5)
assertMerge(new Kurtosis(), values1, values2, expectedKurtosis)
assertMerge(new Skew(), values1, values2, expectedSkew)
}

def testNormalize(): Unit = {
val values = Seq(1.0, 2.0, 3.0, 4.0, 5.0)
val (agg, ir) = makeAgg(new Kurtosis, values)

val normalized = agg.normalize(ir)
val denormalized = agg.denormalize(normalized)

assertEquals(ir, denormalized)
}
}
2 changes: 2 additions & 0 deletions api/py/ai/chronon/group_by.py
@@ -68,6 +68,8 @@ class Operation:
SUM = ttypes.Operation.SUM
AVERAGE = ttypes.Operation.AVERAGE
VARIANCE = ttypes.Operation.VARIANCE
SKEW = ttypes.Operation.SKEW
KURTOSIS = ttypes.Operation.KURTOSIS
HISTOGRAM = ttypes.Operation.HISTOGRAM
# k truncates the map to top_k most frequent items, 0 turns off truncation
HISTOGRAM_K = collector(ttypes.Operation.HISTOGRAM)
4 changes: 2 additions & 2 deletions api/thrift/api.thrift
@@ -150,8 +150,8 @@ enum Operation {
SUM = 7
AVERAGE = 8
VARIANCE = 9
SKEW = 10 // TODO
KURTOSIS = 11 // TODO
SKEW = 10
KURTOSIS = 11
APPROX_PERCENTILE = 12

LAST_K = 13
@@ -75,6 +75,7 @@ class ChrononKryoRegistrator extends KryoRegistrator {
"org.apache.spark.sql.types.Metadata",
"ai.chronon.api.Row",
"ai.chronon.spark.KeyWithHash",
"ai.chronon.aggregator.base.MomentsIR",
"ai.chronon.aggregator.windowing.BatchIr",
"ai.chronon.online.RowWrapper",
"ai.chronon.online.Fetcher$Request",
9 changes: 9 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -223,6 +223,11 @@ class FetcherTest extends TestCase {
operation = Operation.AVERAGE,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
),
Builders.Aggregation(
operation = Operation.SKEW,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
)
),
accuracy = Accuracy.TEMPORAL,
@@ -288,6 +293,10 @@
inputColumn = "rating",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)),
buckets = Seq("bucket")),
Builders.Aggregation(operation = Operation.SKEW,
inputColumn = "rating",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)),
buckets = Seq("bucket")),
Builders.Aggregation(operation = Operation.HISTOGRAM,
inputColumn = "txn_types",
windows = Seq(new Window(3, TimeUnit.DAYS))),
39 changes: 39 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -558,4 +558,43 @@ class GroupByTest {
}
assertEquals(0, diff.count())
}

@Test
def testDescriptiveStats(): Unit = {
val (source, endPartition) = createTestSource(suffix = "_descriptive_stats")
val tableUtils = TableUtils(spark)
val namespace = "test_descriptive_stats"
val aggs = Seq(
Builders.Aggregation(
operation = Operation.VARIANCE,
inputColumn = "price",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
)
),
Builders.Aggregation(
operation = Operation.SKEW,
inputColumn = "price",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
)
),
Builders.Aggregation(
operation = Operation.KURTOSIS,
inputColumn = "price",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
)
),
)
backfill(name = "unit_test_group_by_descriptive_stats",
source = source,
endPartition = endPartition,
namespace = namespace,
tableUtils = tableUtils,
additionalAgg = aggs)
}
}
