Skip to content

Commit

Permalink
Feature: Add Row Level Result Treatment Options for Miminum and Maxim…
Browse files Browse the repository at this point in the history
…um (#535)

* Address comments on PR #532

* Add filtered row-level result support for Minimum, Maximum, Compliance, PatternMatch, MinLength, MaxLength analyzers

* Refactored criterion for MinLength and MaxLength analyzers to separate rowLevelResults logic
  • Loading branch information
eycho-am authored and rdsharma26 committed Apr 16, 2024
1 parent 74d7edb commit 4830640
Show file tree
Hide file tree
Showing 27 changed files with 924 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import com.amazon.deequ.repository._
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
class VerificationRunBuilder(val data: DataFrame) {
class VerificationRunBuilder(val data: DataFrame) {

protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

Expand Down
32 changes: 24 additions & 8 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.metrics.DoubleMetric
Expand Down Expand Up @@ -172,6 +172,12 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
source.load[S](this).foreach { state => target.persist(this, state) }
}

private[deequ] def getRowLevelFilterTreatment(analyzerOptions: Option[AnalyzerOptions]): FilteredRowOutcome = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRowOutcome.TRUE)
}
}

/** An analyzer that runs a set of aggregation functions over the data,
Expand Down Expand Up @@ -257,15 +263,19 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}

case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
filteredRow: FilteredRow = FilteredRow.TRUE)
filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE)
object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
}

object FilteredRow extends Enumeration {
type FilteredRow = Value
object FilteredRowOutcome extends Enumeration {
type FilteredRowOutcome = Value
val NULL, TRUE = Value

implicit class FilteredRowOutcomeOps(value: FilteredRowOutcome) {
def getExpression: Column = expr(value.toString)
}
}

/** Base class for analyzers that compute ratios of matching predicates */
Expand Down Expand Up @@ -484,6 +494,12 @@ private[deequ] object Analyzers {
.getOrElse(selection)
}

def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = {
where
.map { condition => when(condition, replaceWith).otherwise(selection) }
.getOrElse(selection)
}

def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = {
conditionSelectionGivenColumn(selection, where.map(expr), replaceWith)
}
Expand All @@ -500,12 +516,12 @@ private[deequ] object Analyzers {
def conditionalSelectionFilteredFromColumns(
selection: Column,
conditionColumn: Option[Column],
filterTreatment: String)
filterTreatment: FilteredRowOutcome)
: Column = {
conditionColumn
.map { condition => {
when(not(condition), expr(filterTreatment)).when(condition, selection)
} }
.map { condition =>
when(not(condition), filterTreatment.getExpression).when(condition, selection)
}
.getOrElse(selection)
}

Expand Down
11 changes: 2 additions & 9 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
Expand Down Expand Up @@ -54,15 +53,9 @@ case class Completeness(column: String, where: Option[String] = None,
@VisibleForTesting // required by some tests that compare analyzer results to an expected state
private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull

@VisibleForTesting
private[deequ] def rowLevelResults: Column = {
val whereCondition = where.map { expression => expr(expression)}
conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString)
}

private def getRowLevelFilterTreatment: FilteredRow = {
analyzerOptions
.map { options => options.filteredRow }
.getOrElse(FilteredRow.TRUE)
conditionalSelectionFilteredFromColumns(
col(column).isNotNull, whereCondition, getRowLevelFilterTreatment(analyzerOptions))
}
}
19 changes: 17 additions & 2 deletions src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.functions._
import Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.types.DoubleType

/**
* Compliance is a measure of the fraction of rows that complies with the given column constraint.
Expand All @@ -40,14 +41,15 @@ import com.google.common.annotations.VisibleForTesting
case class Compliance(instance: String,
predicate: String,
where: Option[String] = None,
columns: List[String] = List.empty[String])
columns: List[String] = List.empty[String],
analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance)
with FilterableAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {

ifNoNullsIn(result, offset, howMany = 2) { _ =>
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion))
NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
}
}

Expand All @@ -65,6 +67,19 @@ case class Compliance(instance: String,
conditionalSelection(expr(predicate), where).cast(IntegerType)
}

private def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType)
case _ =>
// The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed.
criterion
}
}

override protected def additionalPreconditions(): Seq[StructType => Unit] =
columns.map(hasColumn)
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ object FrequencyBasedAnalyzer {
val fullColumn: Column = {
val window = Window.partitionBy(columnsToGroupBy: _*)
where.map {
condition => {
condition =>
count(when(expr(condition), UNIQUENESS_ID)).over(window)
}
}.getOrElse(count(UNIQUENESS_ID).over(window))
}

Expand Down
33 changes: 27 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

Expand All @@ -33,12 +35,12 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
max(criterion(getNullBehavior)) :: Nil
max(criterion) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {
ifNoNullsIn(result, offset) { _ =>
MaxState(result.getDouble(offset), Some(criterion(getNullBehavior)))
MaxState(result.getDouble(offset), Some(rowLevelResults))
}
}

Expand All @@ -48,15 +50,34 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio

override def filterCondition: Option[String] = where

private def criterion(nullBehavior: NullBehavior): Column = {
private[deequ] def criterion: Column = {
transformColForNullBehavior
}

private[deequ] def rowLevelResults: Column = {
transformColForFilteredRow(criterion)
}

private def transformColForFilteredRow(col: Column): Column = {
val whereNotCondition = where.map { expression => not(expr(expression)) }
getRowLevelFilterTreatment(analyzerOptions) match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue)
case _ =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
}
}

private def transformColForNullBehavior: Column = {
val isNullCheck = col(column).isNull
nullBehavior match {
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
getNullBehavior match {
case NullBehavior.Fail =>
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue)
case NullBehavior.EmptyString =>
length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType)
case _ => length(conditionalSelection(column, where)).cast(DoubleType)
case _ =>
colLengths
}
}

Expand Down
18 changes: 16 additions & 2 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not

case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MaxState] with FullColumn {
Expand All @@ -36,7 +38,7 @@ case class MaxState(maxValue: Double, override val fullColumn: Option[Column] =
}
}

case class Maximum(column: String, where: Option[String] = None)
case class Maximum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[MaxState]("Maximum", column)
with FilterableAnalyzer {

Expand All @@ -47,7 +49,7 @@ case class Maximum(column: String, where: Option[String] = None)
override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {

ifNoNullsIn(result, offset) { _ =>
MaxState(result.getDouble(offset), Some(criterion))
MaxState(result.getDouble(offset), Some(rowLevelResults))
}
}

Expand All @@ -60,5 +62,17 @@ case class Maximum(column: String, where: Option[String] = None)
@VisibleForTesting
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)

private[deequ] def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType)
case _ =>
criterion
}
}

}

33 changes: 27 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/MinLength.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.min
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

Expand All @@ -33,12 +35,12 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
min(criterion(getNullBehavior)) :: Nil
min(criterion) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
ifNoNullsIn(result, offset) { _ =>
MinState(result.getDouble(offset), Some(criterion(getNullBehavior)))
MinState(result.getDouble(offset), Some(rowLevelResults))
}
}

Expand All @@ -48,15 +50,34 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio

override def filterCondition: Option[String] = where

private[deequ] def criterion(nullBehavior: NullBehavior): Column = {
private[deequ] def criterion: Column = {
transformColForNullBehavior
}

private[deequ] def rowLevelResults: Column = {
transformColForFilteredRow(criterion)
}

private def transformColForFilteredRow(col: Column): Column = {
val whereNotCondition = where.map { expression => not(expr(expression)) }
getRowLevelFilterTreatment(analyzerOptions) match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue)
case _ =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
}
}

private def transformColForNullBehavior: Column = {
val isNullCheck = col(column).isNull
nullBehavior match {
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
getNullBehavior match {
case NullBehavior.Fail =>
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue)
case NullBehavior.EmptyString =>
length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType)
case _ => length(conditionalSelection(column, where)).cast(DoubleType)
case _ =>
colLengths
}
}

Expand Down
23 changes: 19 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.not

case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MinState] with FullColumn {
Expand All @@ -36,7 +38,7 @@ case class MinState(minValue: Double, override val fullColumn: Option[Column] =
}
}

case class Minimum(column: String, where: Option[String] = None)
case class Minimum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[MinState]("Minimum", column)
with FilterableAnalyzer {

Expand All @@ -45,9 +47,8 @@ case class Minimum(column: String, where: Option[String] = None)
}

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {

ifNoNullsIn(result, offset) { _ =>
MinState(result.getDouble(offset), Some(criterion))
MinState(result.getDouble(offset), Some(rowLevelResults))
}
}

Expand All @@ -58,5 +59,19 @@ case class Minimum(column: String, where: Option[String] = None)
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)
private def criterion: Column = {
conditionalSelection(column, where).cast(DoubleType)
}

private[deequ] def rowLevelResults: Column = {
val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
val whereNotCondition = where.map { expression => not(expr(expression)) }

filteredRowOutcome match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MaxValue).cast(DoubleType)
case _ =>
criterion
}
}
}
Loading

0 comments on commit 4830640

Please sign in to comment.