Skip to content

Commit

Permalink
Add COUNT_DISTINCT aggregation (feathr-ai#594)
Browse files Browse the repository at this point in the history
* scala COUNT_DISTINCT aggregation

* cr change comment table to match expected results

* cr - fix comment

* cr - remove showing df
  • Loading branch information
esadler-hbo authored and ahlag committed Aug 26, 2022
1 parent acc1a4d commit 382de8b
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.FeatureCol
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils
import com.linkedin.feathr.offline.util.datetime.{DateTimeInterval, OfflineDateTimeUtils}
import com.linkedin.feathr.swj.{FactData, GroupBySpec, LateralViewParams, SlidingWindowFeature, WindowSpec}
import com.linkedin.feathr.swj.aggregate.{AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
import com.linkedin.feathr.swj.aggregate.{AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, CountDistinctAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
import org.apache.log4j.Logger
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
Expand Down Expand Up @@ -178,6 +178,7 @@ private[offline] object SlidingWindowFeatureUtils {
// In Feathr's use case, we want to treat the count aggregation as simple count of non-null items.
val rewrittenDef = s"CASE WHEN ${featureDef} IS NOT NULL THEN 1 ELSE 0 END"
new CountAggregate(rewrittenDef)
case AggregationType.COUNT_DISTINCT => new CountDistinctAggregate(featureDef)
case AggregationType.AVG => new AvgAggregate(featureDef)
case AggregationType.MAX => new MaxAggregate(featureDef)
case AggregationType.MIN => new MinAggregate(featureDef)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package com.linkedin.feathr.swj.aggregate

object AggregationType extends Enumeration {
type AggregationType = Value
val SUM, COUNT, AVG, MAX, TIMESINCE, LATEST, DUMMY, MIN, MAX_POOLING, MIN_POOLING, AVG_POOLING, SUM_POOLING = Value
val SUM, COUNT, COUNT_DISTINCT, AVG, MAX, TIMESINCE, LATEST, DUMMY, MIN, MAX_POOLING, MIN_POOLING, AVG_POOLING, SUM_POOLING = Value
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.linkedin.feathr.swj.aggregate

import com.linkedin.feathr.swj.aggregate.AggregationType._
import org.apache.spark.sql.types._

/**
* COUNT_DISTINCT aggregation implementation.
*
* @param metricCol Name of the metric column or a Spark SQL column expression for derived metric
* that will be aggregated using COUNT_DISTINCT.
*/
class CountDistinctAggregate(val metricCol: String) extends AggregationSpec {
override def aggregation: AggregationType = COUNT_DISTINCT

override def metricName = "count_distinct_col"

override def isIncrementalAgg = false

override def isCalculateAggregateNeeded: Boolean = true

override def calculateAggregate(aggregate: Any, dataType: DataType): Any = {
if (aggregate == null) {
aggregate
} else {
dataType match {
case IntegerType => aggregate.asInstanceOf[Set[Int]].size
case LongType => aggregate.asInstanceOf[Set[Long]].size
case DoubleType => aggregate.asInstanceOf[Set[Double]].size
case FloatType => aggregate.asInstanceOf[Set[Float]].size
case StringType => aggregate.asInstanceOf[Set[String]].size
case _ => throw new RuntimeException(s"Invalid data type for COUNT_DISTINCT metric col $metricCol. " +
s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
}
}
}

override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
if (aggregate == null) {
Set(record)
} else if (record == null) {
aggregate
} else {
dataType match {
case IntegerType => aggregate.asInstanceOf[Set[Int]] + record.asInstanceOf[Int]
case LongType => aggregate.asInstanceOf[Set[Long]] + record.asInstanceOf[Long]
case DoubleType => aggregate.asInstanceOf[Set[Double]] + record.asInstanceOf[Double]
case FloatType => aggregate.asInstanceOf[Set[Float]] + record.asInstanceOf[Float]
case StringType=> aggregate.asInstanceOf[Set[String]] + record.asInstanceOf[String]
case _ => throw new RuntimeException(s"Invalid data type for COUNT_DISTINCT metric col $metricCol. " +
s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
}
}
}

override def deagg(aggregate: Any, record: Any, dataType: DataType): Any = {
throw new RuntimeException("Method deagg for COUNT_DISTINCT aggregate is not implemented because COUNT_DISTINCT is " +
"not an incremental aggregation.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -983,4 +983,86 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {

validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
}


@Test
def testSWACountDistinct(): Unit = {
val featureDefAsString =
"""
|sources: {
| swaSource: {
| location: { path: "generation/daily/" }
| isTimeSeries: true
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|anchors: {
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| key: [x]
| features: {
| f: {
| def: "Id" // the column that contains the raw view count
| aggregation: COUNT
| window: 10d
| }
| g: {
| def: "Id" // the column that contains the raw view count
| aggregation: COUNT_DISTINCT
| window: 10d
| }
| }
| }
|}
""".stripMargin

val features = Seq("f", "g")
val keyField = "x"
val featureJoinAsString =
s"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: yyyy-MM-dd
| }
| }
|}
|features: [
| {
| key: [$keyField],
| featureList: [${features.mkString(",")}]
| }
|]
""".stripMargin


/**
* Expected output:
* +--------+----+----+
* |x| f| g|
* +--------+----+----+
* | 1| 6| 2|
* | 2| 5| 2|
* | 3| 1| 1|
* +--------+----+----+
*/
val expectedSchema = StructType(
Seq(
StructField(keyField, LongType),
StructField(features.head, LongType), // f
StructField(features.last, LongType) // g
))

val expectedRows = Array(
new GenericRowWithSchema(Array(1, 6, 2), expectedSchema),
new GenericRowWithSchema(Array(2, 5, 2), expectedSchema),
new GenericRowWithSchema(Array(3, 1, 1), expectedSchema))
val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data

validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
}
}

0 comments on commit 382de8b

Please sign in to comment.