Feature: Add Row Level Result Treatment Options for Minimum and Maximum #535

Merged · 7 commits · Feb 21, 2024
@@ -25,7 +25,7 @@ import com.amazon.deequ.repository._
import org.apache.spark.sql.{DataFrame, SparkSession}

/** A class to build a VerificationRun using a fluent API */
-class VerificationRunBuilder(val data: DataFrame) {
+class VerificationRunBuilder(val data: DataFrame) {

protected var requiredAnalyzers: Seq[Analyzer[_, Metric[_]]] = Seq.empty

32 changes: 24 additions & 8 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -17,7 +17,7 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
-import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
+import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.runners._
import com.amazon.deequ.metrics.DoubleMetric
@@ -172,6 +172,12 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
private[deequ] def copyStateTo(source: StateLoader, target: StatePersister): Unit = {
source.load[S](this).foreach { state => target.persist(this, state) }
}

+  private[deequ] def getRowLevelFilterTreatment(analyzerOptions: Option[AnalyzerOptions]): FilteredRowOutcome = {
+    analyzerOptions
+      .map { options => options.filteredRow }
+      .getOrElse(FilteredRowOutcome.TRUE)
+  }
}
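For context (not part of the diff): a minimal sketch of what the new helper resolves to. The names mirror the diff above; the surrounding values are illustrative only.

```scala
// default: no options supplied, so filtered rows count as passing
val defaultOutcome = getRowLevelFilterTreatment(None)
// => FilteredRowOutcome.TRUE

// opting in to null-ing out filtered rows instead
val nullOutcome = getRowLevelFilterTreatment(
  Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
// => FilteredRowOutcome.NULL
```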

/** An analyzer that runs a set of aggregation functions over the data,
@@ -257,15 +263,19 @@ case class NumMatchesAndCount(numMatches: Long, count: Long, override val fullCo
}

case class AnalyzerOptions(nullBehavior: NullBehavior = NullBehavior.Ignore,
-                           filteredRow: FilteredRow = FilteredRow.TRUE)
+                           filteredRow: FilteredRowOutcome = FilteredRowOutcome.TRUE)
object NullBehavior extends Enumeration {
type NullBehavior = Value
val Ignore, EmptyString, Fail = Value
}

-object FilteredRow extends Enumeration {
-  type FilteredRow = Value
+object FilteredRowOutcome extends Enumeration {
+  type FilteredRowOutcome = Value
   val NULL, TRUE = Value
+
+  implicit class FilteredRowOutcomeOps(value: FilteredRowOutcome) {
+    def getExpression: Column = expr(value.toString)
+  }
}
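The implicit ops class turns an enum value into a Spark Column by parsing its name as a SQL expression. A minimal sketch of the mapping, assuming the imports shown in the diff:

```scala
// FilteredRowOutcome.TRUE.getExpression is expr("TRUE"): the row counts as passing
// FilteredRowOutcome.NULL.getExpression is expr("NULL"): the row carries no outcome
val passMarker = FilteredRowOutcome.TRUE.getExpression
val nullMarker = FilteredRowOutcome.NULL.getExpression
```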

/** Base class for analyzers that compute ratios of matching predicates */
@@ -484,6 +494,12 @@ private[deequ] object Analyzers {
.getOrElse(selection)
}

+  def conditionSelectionGivenColumn(selection: Column, where: Option[Column], replaceWith: Boolean): Column = {
+    where
+      .map { condition => when(condition, replaceWith).otherwise(selection) }
+      .getOrElse(selection)
+  }

def conditionalSelection(selection: Column, where: Option[String], replaceWith: Double): Column = {
conditionSelectionGivenColumn(selection, where.map(expr), replaceWith)
}
@@ -500,12 +516,12 @@
def conditionalSelectionFilteredFromColumns(
selection: Column,
conditionColumn: Option[Column],
-    filterTreatment: String)
+    filterTreatment: FilteredRowOutcome)
: Column = {
conditionColumn
-    .map { condition => {
-      when(not(condition), expr(filterTreatment)).when(condition, selection)
-    } }
+    .map { condition =>
+      when(not(condition), filterTreatment.getExpression).when(condition, selection)
+    }
.getOrElse(selection)
}
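Spelled out, the rewritten helper builds a two-branch `when` chain: rows failing the filter get the configured outcome, rows passing it get the real row-level result. A hedged sketch of the equivalent expression, with hypothetical column and filter names:

```scala
import org.apache.spark.sql.functions.{col, expr, not, when}

val condition = expr("item < 3")       // hypothetical where-filter
val selection = col("att1").isNotNull  // hypothetical row-level check

// equivalent to conditionalSelectionFilteredFromColumns(
//   selection, Some(condition), FilteredRowOutcome.NULL)
val rowLevel = when(not(condition), expr("NULL")).when(condition, selection)
```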

11 changes: 2 additions & 9 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
@@ -20,7 +20,6 @@ import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNotNested}
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.{IntegerType, StructType}
import Analyzers._
-import com.amazon.deequ.analyzers.FilteredRow.FilteredRow
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
@@ -54,15 +53,9 @@ case class Completeness(column: String, where: Option[String] = None,
@VisibleForTesting // required by some tests that compare analyzer results to an expected state
private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull

  @VisibleForTesting
  private[deequ] def rowLevelResults: Column = {
    val whereCondition = where.map { expression => expr(expression)}
-    conditionalSelectionFilteredFromColumns(col(column).isNotNull, whereCondition, getRowLevelFilterTreatment.toString)
-  }
-
-  private def getRowLevelFilterTreatment: FilteredRow = {
-    analyzerOptions
-      .map { options => options.filteredRow }
-      .getOrElse(FilteredRow.TRUE)
+    conditionalSelectionFilteredFromColumns(
+      col(column).isNotNull, whereCondition, getRowLevelFilterTreatment(analyzerOptions))
  }
}
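With this change a caller can choose how filtered rows surface in Completeness row-level results. A usage sketch with hypothetical column and filter names (the constructor signature is as in the diff):

```scala
// rows with item >= 3 are excluded from the metric and surface as null
// in the row-level results column instead of the default true
val completeness = Completeness(
  "att1",
  where = Some("item < 3"),
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
```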
18 changes: 16 additions & 2 deletions src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.functions._
import Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.google.common.annotations.VisibleForTesting
+import org.apache.spark.sql.types.DoubleType

/**
* Compliance is a measure of the fraction of rows that complies with the given column constraint.
@@ -40,14 +41,15 @@ import com.google.common.annotations.VisibleForTesting
case class Compliance(instance: String,
predicate: String,
where: Option[String] = None,
-                      columns: List[String] = List.empty[String])
+                      columns: List[String] = List.empty[String],
+                      analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance)
with FilterableAnalyzer {

override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {

ifNoNullsIn(result, offset, howMany = 2) { _ =>
-      NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(criterion))
+      NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
}
}

@@ -65,6 +67,18 @@ case class Compliance(instance: String,
conditionalSelection(expr(predicate), where).cast(IntegerType)
}

+  private def rowLevelResults: Column = {
+    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
+    val whereNotCondition = where.map { expression => not(expr(expression)) }
+
+    filteredRowOutcome match {
+      case FilteredRowOutcome.TRUE =>
+        conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType)
+      case _ =>
+        criterion
Contributor: Can we add a comment here that we don't need special treatment for NULL, because that is already the default behavior when using `where`?

+    }
+  }

override protected def additionalPreconditions(): Seq[StructType => Unit] =
columns.map(hasColumn)
}
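For the TRUE outcome the method negates the filter and overwrites the excluded rows before the integer cast; for NULL it can simply fall back to criterion, because conditionalSelection already leaves filtered rows null. A sketch of the TRUE branch with a hypothetical predicate and filter:

```scala
import org.apache.spark.sql.functions.{expr, not}
import org.apache.spark.sql.types.IntegerType

// rows excluded by the filter are forced to true (1 after the cast)
val rowLevel = conditionSelectionGivenColumn(
  expr("att1 > 0"),             // hypothetical predicate
  Some(not(expr("item < 3"))),  // negated hypothetical filter
  replaceWith = true).cast(IntegerType)
```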
@@ -93,9 +93,8 @@ object FrequencyBasedAnalyzer {
val fullColumn: Column = {
val window = Window.partitionBy(columnsToGroupBy: _*)
where.map {
-      condition => {
+      condition =>
        count(when(expr(condition), UNIQUENESS_ID)).over(window)
-      }
}.getOrElse(count(UNIQUENESS_ID).over(window))
}

37 changes: 30 additions & 7 deletions src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
@@ -17,14 +17,17 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
+import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.max
+import org.apache.spark.sql.functions.not
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

@@ -33,12 +36,12 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
-    max(criterion(getNullBehavior)) :: Nil
+    max(criterion) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {
ifNoNullsIn(result, offset) { _ =>
-      MaxState(result.getDouble(offset), Some(criterion(getNullBehavior)))
+      MaxState(result.getDouble(offset), Some(rowLevelResults))
}
}

@@ -48,15 +51,35 @@

override def filterCondition: Option[String] = where

-  private def criterion(nullBehavior: NullBehavior): Column = {
-    val isNullCheck = col(column).isNull
+  private[deequ] def criterion: Column = {
+    transformColForNullBehavior(col(column), getNullBehavior)
+  }
+
+  private[deequ] def rowLevelResults: Column = {
+    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
+    transformColForFilteredRow(criterion, filteredRowOutcome)
+  }
+
+  private def transformColForFilteredRow(col: Column, rowOutcome: FilteredRowOutcome): Column = {
+    val whereNotCondition = where.map { expression => not(expr(expression)) }
+    rowOutcome match {
+      case FilteredRowOutcome.TRUE =>
+        conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue)
+      case _ =>
+        conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
+    }
+  }

+  private def transformColForNullBehavior(col: Column, nullBehavior: NullBehavior): Column = {
Contributor: Can we use column here directly instead of passing it in as `col`? In rowLevelResults, we use analyzerOptions directly instead of passing it in as a parameter.

+    val isNullCheck = col.isNull
+    val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
    nullBehavior match {
      case NullBehavior.Fail =>
-        val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
        conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue)
      case NullBehavior.EmptyString =>
-        length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType)
-      case _ => length(conditionalSelection(column, where)).cast(DoubleType)
+        length(conditionSelectionGivenColumn(col, Option(isNullCheck), replaceWith = "")).cast(DoubleType)
+      case _ =>
+        colLengths
    }
  }

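The sentinel values matter here: filtered rows become Double.MinValue so they can never raise the observed maximum, while NullBehavior.Fail maps nulls to Double.MaxValue so any maximum-length constraint trips on them. A usage sketch with hypothetical names:

```scala
val maxLength = MaxLength(
  "att1",
  where = Some("item < 3"),
  analyzerOptions = Some(AnalyzerOptions(
    nullBehavior = NullBehavior.Fail,         // nulls -> Double.MaxValue
    filteredRow = FilteredRowOutcome.TRUE)))  // filtered rows -> Double.MinValue
```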
18 changes: 16 additions & 2 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
@@ -23,6 +23,8 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
+import org.apache.spark.sql.functions.expr
+import org.apache.spark.sql.functions.not

case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MaxState] with FullColumn {
@@ -36,7 +38,7 @@ case class MaxState(maxValue: Double, override val fullColumn: Option[Column] =
}
}

-case class Maximum(column: String, where: Option[String] = None)
+case class Maximum(column: String, where: Option[String] = None, analyzerOptions: Option[AnalyzerOptions] = None)
extends StandardScanShareableAnalyzer[MaxState]("Maximum", column)
with FilterableAnalyzer {

@@ -47,7 +49,7 @@ case class Maximum(column: String, where: Option[String] = None)
override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {

ifNoNullsIn(result, offset) { _ =>
-      MaxState(result.getDouble(offset), Some(criterion))
+      MaxState(result.getDouble(offset), Some(rowLevelResults))
}
}

@@ -60,5 +62,17 @@ case class Maximum(column: String, where: Option[String] = None)
@VisibleForTesting
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)

+  private[deequ] def rowLevelResults: Column = {
+    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
+    val whereNotCondition = where.map { expression => not(expr(expression)) }
+
+    filteredRowOutcome match {
+      case FilteredRowOutcome.TRUE =>
+        conditionSelectionGivenColumn(col(column), whereNotCondition, replaceWith = Double.MinValue).cast(DoubleType)
+      case _ =>
+        criterion
+    }
+  }

}
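Maximum follows the same pattern: under the default TRUE outcome, filtered rows are pinned to Double.MinValue so they never win the max; under NULL the code falls through to criterion, whose where-clause already nulls them. A sketch with hypothetical names:

```scala
val maximum = Maximum(
  "price",                        // hypothetical column
  where = Some("region = 'EU'"),  // hypothetical filter
  analyzerOptions = Some(AnalyzerOptions(filteredRow = FilteredRowOutcome.NULL)))
```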

37 changes: 30 additions & 7 deletions src/main/scala/com/amazon/deequ/analyzers/MinLength.scala
@@ -17,14 +17,17 @@
package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Analyzers._
+import com.amazon.deequ.analyzers.FilteredRowOutcome.FilteredRowOutcome
import com.amazon.deequ.analyzers.NullBehavior.NullBehavior
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.min
+import org.apache.spark.sql.functions.not
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

@@ -33,12 +36,12 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
-    min(criterion(getNullBehavior)) :: Nil
+    min(criterion) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
ifNoNullsIn(result, offset) { _ =>
-      MinState(result.getDouble(offset), Some(criterion(getNullBehavior)))
+      MinState(result.getDouble(offset), Some(rowLevelResults))
}
}

@@ -48,15 +51,35 @@

override def filterCondition: Option[String] = where

-  private[deequ] def criterion(nullBehavior: NullBehavior): Column = {
-    val isNullCheck = col(column).isNull
+  private[deequ] def criterion: Column = {
+    transformColForNullBehavior(col(column), getNullBehavior)
+  }
+
+  private[deequ] def rowLevelResults: Column = {
+    val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
+    transformColForFilteredRow(criterion, filteredRowOutcome)
+  }
+
+  private def transformColForFilteredRow(col: Column, rowOutcome: FilteredRowOutcome): Column = {
+    val whereNotCondition = where.map { expression => not(expr(expression)) }
+    rowOutcome match {
+      case FilteredRowOutcome.TRUE =>
+        conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue)
+      case _ =>
+        conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
+    }
+  }
+
+  private def transformColForNullBehavior(col: Column, nullBehavior: NullBehavior): Column = {
+    val isNullCheck = col.isNull
+    val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
nullBehavior match {
case NullBehavior.Fail =>
-        val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue)
case NullBehavior.EmptyString =>
-        length(conditionSelectionGivenColumn(col(column), Option(isNullCheck), replaceWith = "")).cast(DoubleType)
-      case _ => length(conditionalSelection(column, where)).cast(DoubleType)
+        length(conditionSelectionGivenColumn(col, Option(isNullCheck), replaceWith = "")).cast(DoubleType)
+      case _ =>
+        colLengths
}
}

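MinLength mirrors MaxLength with the sentinels flipped: filtered rows become Double.MaxValue (they can never lower the minimum) and NullBehavior.Fail maps nulls to Double.MinValue (always below any minimum-length threshold). A sketch with hypothetical names:

```scala
val minLength = MinLength(
  "att1",
  where = Some("item < 3"),
  analyzerOptions = Some(AnalyzerOptions(
    nullBehavior = NullBehavior.Fail,         // nulls -> Double.MinValue
    filteredRow = FilteredRowOutcome.TRUE)))  // filtered rows -> Double.MaxValue
```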