Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Wilson Score Confidence Interval Strategy #567

Merged
merged 13 commits into from
May 24, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,17 @@ import com.amazon.deequ.metrics.DistributionValue
import com.amazon.deequ.profiles.ColumnProfile
import com.amazon.deequ.suggestions.ConstraintSuggestion
import com.amazon.deequ.suggestions.ConstraintSuggestionWithValue
import com.amazon.deequ.suggestions.rules.FractionalCategoricalRangeRule.defaultIntervalStrategy
import com.amazon.deequ.suggestions.rules.interval.{ConfidenceIntervalStrategy, WilsonScoreIntervalStrategy}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Could we avoid grouped imports and use one import per line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: do we prefer separate import statements, or a single grouped import with each name on its own line?

import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy
import com.amazon.deequ.suggestions.rules.interval.WilsonScoreIntervalStrategy

or

import com.amazon.deequ.suggestions.rules.interval.{
  ConfidenceIntervalStrategy,
  WilsonScoreIntervalStrategy
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The former. It helps with automatic resolution of merge conflicts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

import org.apache.commons.lang3.StringEscapeUtils

import scala.math.BigDecimal.RoundingMode

/** If we see a categorical range for most values in a column, we suggest an IS IN (...)
* constraint that should hold for most values */
case class FractionalCategoricalRangeRule(
targetDataCoverageFraction: Double = 0.9,
categorySorter: Array[(String, DistributionValue)] => Array[(String, DistributionValue)] =
categories => categories.sortBy({ case (_, value) => value.absolute }).reverse
categories => categories.sortBy({ case (_, value) => value.absolute }).reverse,
intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy
) extends ConstraintRule[ColumnProfile] {

override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = {
Expand Down Expand Up @@ -79,11 +80,8 @@ case class FractionalCategoricalRangeRule(

val p = ratioSums
val n = numRecords
val z = 1.96

// TODO this needs to be more robust for p's close to 0 or 1
val targetCompliance = BigDecimal(p - z * math.sqrt(p * (1 - p) / n))
.setScale(2, RoundingMode.DOWN).toDouble
val targetCompliance = intervalStrategy.calculateTargetConfidenceInterval(p, n).lowerBound

val description = s"'${profile.column}' has value range $categoriesSql for at least " +
s"${targetCompliance * 100}% of values"
Expand Down Expand Up @@ -128,3 +126,7 @@ case class FractionalCategoricalRangeRule(
override val ruleDescription: String = "If we see a categorical range for most values " +
"in a column, we suggest an IS IN (...) constraint that should hold for most values"
}

object FractionalCategoricalRangeRule {
private val defaultIntervalStrategy: ConfidenceIntervalStrategy = WilsonScoreIntervalStrategy()
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ import com.amazon.deequ.profiles.ColumnProfile
import com.amazon.deequ.suggestions.CommonConstraintSuggestion
import com.amazon.deequ.suggestions.ConstraintSuggestion
import com.amazon.deequ.suggestions.rules.RetainCompletenessRule._

import scala.math.BigDecimal.RoundingMode
import com.amazon.deequ.suggestions.rules.interval.{ConfidenceIntervalStrategy, WilsonScoreIntervalStrategy}

/**
* If a column is incomplete in the sample, we model its completeness as a binomial variable,
Expand All @@ -33,21 +32,15 @@ import scala.math.BigDecimal.RoundingMode
*/
case class RetainCompletenessRule(
minCompleteness: Double = defaultMinCompleteness,
maxCompleteness: Double = defaultMaxCompleteness
maxCompleteness: Double = defaultMaxCompleteness,
intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy
) extends ConstraintRule[ColumnProfile] {
override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = {
profile.completeness > minCompleteness && profile.completeness < maxCompleteness
}

override def candidate(profile: ColumnProfile, numRecords: Long): ConstraintSuggestion = {

val p = profile.completeness
val n = numRecords
val z = 1.96

// TODO this needs to be more robust for p's close to 0 or 1
val targetCompleteness = BigDecimal(p - z * math.sqrt(p * (1 - p) / n))
.setScale(2, RoundingMode.DOWN).toDouble
val targetCompleteness = intervalStrategy.calculateTargetConfidenceInterval(profile.completeness, numRecords).lowerBound

val constraint = completenessConstraint(profile.column, _ >= targetCompleteness)

Expand Down Expand Up @@ -75,4 +68,5 @@ case class RetainCompletenessRule(
object RetainCompletenessRule {
private val defaultMinCompleteness: Double = 0.2
private val defaultMaxCompleteness: Double = 1.0
private val defaultIntervalStrategy: ConfidenceIntervalStrategy = WilsonScoreIntervalStrategy()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.amazon.deequ.suggestions.rules.interval

import breeze.stats.distributions.{Gaussian, Rand}
import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}

/**
 * Strategy for calculating a binomial proportion confidence interval.
 */
trait ConfidenceIntervalStrategy {

  /**
   * Generates a confidence interval for a binomial proportion.
   *
   * @param pHat fraction of the sampled population that shares the trait (must be in [0, 1])
   * @param numRecords overall number of records in the sample
   * @param confidence confidence level of the method used to estimate the interval (must be in [0, 1])
   * @return the lower and upper bound of the estimated interval
   */
  def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval

  // Fails fast with IllegalArgumentException when pHat or confidence lies outside [0, 1].
  def validateInput(pHat: Double, confidence: Double): Unit = {
    require(0.0 <= pHat && pHat <= 1.0, "pHat must be between 0.0 and 1.0")
    require(0.0 <= confidence && confidence <= 1.0, "confidence must be between 0.0 and 1.0")
  }

  // Two-sided critical value: the (1 - (1 - confidence) / 2) quantile of the
  // standard normal distribution (e.g. ~1.96 for the default 0.95 confidence).
  def calculateZScore(confidence: Double): Double = Gaussian(0, 1)(Rand).inverseCdf(1 - ((1.0 - confidence)/ 2.0))
}

object ConfidenceIntervalStrategy {
val defaultConfidence = 0.95

case class ConfidenceInterval(lowerBound: Double, upperBound: Double)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we also calculate the upperBound for these ConfidenceIntervals, although at the moment we don't actually make use of it.

}


Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.amazon.deequ.suggestions.rules.interval

import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}

import scala.math.BigDecimal.RoundingMode

/**
 * Implements the Wald Interval method for creating a binomial proportion confidence interval.
 *
 * @see <a
 * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Normal_approximation_interval">
 * Normal approximation interval (Wikipedia)</a>
 */
case class WaldIntervalStrategy() extends ConfidenceIntervalStrategy {

  /**
   * Computes pHat ± z * sqrt(pHat * (1 - pHat) / numRecords), rounding the lower
   * bound down and the upper bound up to two decimal places.
   *
   * @param pHat observed success ratio (must be in [0, 1])
   * @param numRecords overall number of records in the sample
   * @param confidence confidence level of the interval (must be in [0, 1])
   * @return the lower and upper bound of the Wald interval
   */
  def calculateTargetConfidenceInterval(
      pHat: Double,
      numRecords: Long,
      confidence: Double = defaultConfidence): ConfidenceInterval = {
    validateInput(pHat, confidence)
    val standardError = math.sqrt(pHat * (1 - pHat) / numRecords)
    val delta = BigDecimal(calculateZScore(confidence) * standardError)
    val center = BigDecimal(pHat)
    ConfidenceInterval(
      lowerBound = (center - delta).setScale(2, RoundingMode.DOWN).toDouble,
      upperBound = (center + delta).setScale(2, RoundingMode.UP).toDouble
    )
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.amazon.deequ.suggestions.rules.interval

import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}

import scala.math.BigDecimal.RoundingMode

/**
 * Using Wilson score method for creating a binomial proportion confidence interval.
 * More robust than the normal approximation for pHat close to 0 or 1.
 *
 * @see <a
 * href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval">
 * Wilson score interval (Wikipedia)</a>
 */
case class WilsonScoreIntervalStrategy() extends ConfidenceIntervalStrategy {

  /**
   * Computes the Wilson score interval
   * (pHat + z²/2n ± z * sqrt(pHat(1 - pHat)/n + z²/4n²)) / (1 + z²/n),
   * rounding the lower bound down and the upper bound up to two decimal places.
   *
   * @param pHat observed success ratio (must be in [0, 1])
   * @param numRecords overall number of records in the sample
   * @param confidence confidence level of the interval (must be in [0, 1])
   * @return the lower and upper bound of the Wilson score interval
   */
  def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval = {
    validateInput(pHat, confidence)
    val zScore = calculateZScore(confidence)
    // z² / n — the correction term that pulls the interval's center toward 0.5
    val zSquareOverN = math.pow(zScore, 2) / numRecords
    // 1 / (1 + z²/n) — normalizing factor for the adjusted center and margin
    val factor = 1.0 / (1 + zSquareOverN)
    val adjustedSuccessRatio = pHat + zSquareOverN/2
    val marginOfError = zScore * math.sqrt(pHat * (1 - pHat)/numRecords + zSquareOverN/(4 * numRecords))
    val lowerBound = BigDecimal(factor * (adjustedSuccessRatio - marginOfError)).setScale(2, RoundingMode.DOWN).toDouble
    val upperBound = BigDecimal(factor * (adjustedSuccessRatio + marginOfError)).setScale(2, RoundingMode.UP).toDouble
    ConfidenceInterval(lowerBound, upperBound)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ import com.amazon.deequ.checks.{Check, CheckLevel}
import com.amazon.deequ.constraints.ConstrainableDataTypes
import com.amazon.deequ.metrics.{Distribution, DistributionValue}
import com.amazon.deequ.profiles._
import com.amazon.deequ.suggestions.rules.interval.{WaldIntervalStrategy, WilsonScoreIntervalStrategy}
import com.amazon.deequ.utils.FixtureSupport
import com.amazon.deequ.{SparkContextSpec, VerificationSuite}
import org.scalamock.scalatest.MockFactory
import org.scalatest.Inspectors.forAll
import org.scalatest.WordSpec
import org.scalatest.prop.Tables.Table

class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContextSpec
with MockFactory{
Expand Down Expand Up @@ -132,81 +135,98 @@ class ConstraintRulesTest extends WordSpec with FixtureSupport with SparkContext
val complete = StandardColumnProfile("col1", 1.0, 100, String, false, Map.empty, None)
val tenPercent = StandardColumnProfile("col1", 0.1, 100, String, false, Map.empty, None)
val incomplete = StandardColumnProfile("col1", .25, 100, String, false, Map.empty, None)
val waldIntervalStrategy = WaldIntervalStrategy()

assert(!RetainCompletenessRule().shouldBeApplied(complete, 1000))
assert(!RetainCompletenessRule(0.05, 0.9).shouldBeApplied(complete, 1000))
assert(RetainCompletenessRule(0.05, 0.9).shouldBeApplied(tenPercent, 1000))
assert(RetainCompletenessRule(0.0).shouldBeApplied(tenPercent, 1000))
assert(RetainCompletenessRule(0.0).shouldBeApplied(incomplete, 1000))
assert(RetainCompletenessRule().shouldBeApplied(incomplete, 1000))
assert(!RetainCompletenessRule(intervalStrategy = waldIntervalStrategy).shouldBeApplied(complete, 1000))
assert(!RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(complete, 1000))
assert(RetainCompletenessRule(0.05, 0.9, waldIntervalStrategy).shouldBeApplied(tenPercent, 1000))
}

"return evaluable constraint candidates" in
withSparkSession { session =>
val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true))
forAll(table) { case (strategy, result) =>
val dfWithColumnCandidate = getDfFull(session)

val dfWithColumnCandidate = getDfFull(session)
val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)

val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
val check = Check(CheckLevel.Warning, "some")
.addConstraint(RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100).constraint)

val check = Check(CheckLevel.Warning, "some")
.addConstraint(RetainCompletenessRule().candidate(fakeColumnProfile, 100).constraint)
val verificationResult = VerificationSuite()
.onData(dfWithColumnCandidate)
.addCheck(check)
.run()

val verificationResult = VerificationSuite()
.onData(dfWithColumnCandidate)
.addCheck(check)
.run()
val metricResult = verificationResult.metrics.head._2

val metricResult = verificationResult.metrics.head._2
assert(metricResult.value.isSuccess == result)
}

assert(metricResult.value.isSuccess)
}

"return working code to add constraint to check" in
withSparkSession { session =>
val table = Table(
("strategy", "colCompleteness", "targetCompleteness", "result"),
(WaldIntervalStrategy(), 0.5, 0.4, true),
(WilsonScoreIntervalStrategy(), 0.4, 0.3, true)
)
forAll(table) { case (strategy, colCompleteness, targetCompleteness, result) =>

val dfWithColumnCandidate = getDfFull(session)
val dfWithColumnCandidate = getDfFull(session)

val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", colCompleteness)

val codeForConstraint = RetainCompletenessRule().candidate(fakeColumnProfile, 100)
.codeForConstraint
val codeForConstraint = RetainCompletenessRule(intervalStrategy = strategy).candidate(fakeColumnProfile, 100)
.codeForConstraint

val expectedCodeForConstraint = """.hasCompleteness("att1", _ >= 0.4,
| Some("It should be above 0.4!"))""".stripMargin.replaceAll("\n", "")
val expectedCodeForConstraint = s""".hasCompleteness("att1", _ >= $targetCompleteness,
| Some("It should be above $targetCompleteness!"))""".stripMargin.replaceAll("\n", "")

assert(expectedCodeForConstraint == codeForConstraint)
assert(expectedCodeForConstraint == codeForConstraint)

val check = Check(CheckLevel.Warning, "some")
.hasCompleteness("att1", _ >= 0.4, Some("It should be above 0.4!"))
val check = Check(CheckLevel.Warning, "some")
.hasCompleteness("att1", _ >= targetCompleteness, Some(s"It should be above $targetCompleteness"))

val verificationResult = VerificationSuite()
.onData(dfWithColumnCandidate)
.addCheck(check)
.run()
val verificationResult = VerificationSuite()
.onData(dfWithColumnCandidate)
.addCheck(check)
.run()

val metricResult = verificationResult.metrics.head._2
val metricResult = verificationResult.metrics.head._2

assert(metricResult.value.isSuccess == result)
}

assert(metricResult.value.isSuccess)
}

"return evaluable constraint candidates with custom min/max completeness" in
withSparkSession { session =>
val table = Table(("strategy", "result"), (WaldIntervalStrategy(), true), (WilsonScoreIntervalStrategy(), true))
forAll(table) { case (strategy, result) =>
val dfWithColumnCandidate = getDfFull(session)

val dfWithColumnCandidate = getDfFull(session)

val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)
val fakeColumnProfile = getFakeColumnProfileWithNameAndCompleteness("att1", 0.5)

val check = Check(CheckLevel.Warning, "some")
.addConstraint(RetainCompletenessRule(0.4, 0.6).candidate(fakeColumnProfile, 100).constraint)
val check = Check(CheckLevel.Warning, "some")
.addConstraint(RetainCompletenessRule(0.4, 0.6, strategy).candidate(fakeColumnProfile, 100).constraint)

val verificationResult = VerificationSuite()
.onData(dfWithColumnCandidate)
.addCheck(check)
.run()
val verificationResult = VerificationSuite()
.onData(dfWithColumnCandidate)
.addCheck(check)
.run()

val metricResult = verificationResult.metrics.head._2
val metricResult = verificationResult.metrics.head._2

assert(metricResult.value.isSuccess)
assert(metricResult.value.isSuccess == result)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.amazon.deequ.suggestions.rules.interval

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval
import com.amazon.deequ.utils.FixtureSupport
import org.scalamock.scalatest.MockFactory
import org.scalatest.Inspectors.forAll
import org.scalatest.prop.Tables.Table
import org.scalatest.wordspec.AnyWordSpec

class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec
  with MockFactory {

  "ConfidenceIntervalStrategy" should {
    "be calculated correctly" in {
      val waldStrategy = WaldIntervalStrategy()
      val wilsonStrategy = WilsonScoreIntervalStrategy()

      // Expected bounds are pre-computed at the default 95% confidence level and
      // rounded to two decimal places (lower bound down, upper bound up).
      // Column renamed "numRecord" -> "numRecords" to match the bound variable below;
      // trailing comma after the last row removed (only legal on Scala 2.12.2+).
      val table = Table(
        ("strategy", "pHat", "numRecords", "lowerBound", "upperBound"),
        (waldStrategy, 1.0, 20L, 1.0, 1.0),
        (waldStrategy, 0.5, 100L, 0.4, 0.6),
        (waldStrategy, 0.4, 100L, 0.3, 0.5),
        (waldStrategy, 0.6, 100L, 0.5, 0.7),
        (waldStrategy, 0.9, 100L, 0.84, 0.96),
        (waldStrategy, 1.0, 100L, 1.0, 1.0),

        // Wilson stays well-behaved for pHat near 0 or 1, unlike the Wald interval.
        (wilsonStrategy, 0.01, 20L, 0.00, 0.18),
        (wilsonStrategy, 1.0, 20L, 0.83, 1.0),
        (wilsonStrategy, 0.5, 100L, 0.4, 0.6),
        (wilsonStrategy, 0.4, 100L, 0.3, 0.5),
        (wilsonStrategy, 0.6, 100L, 0.5, 0.7),
        (wilsonStrategy, 0.9, 100L, 0.82, 0.95),
        (wilsonStrategy, 1.0, 100L, 0.96, 1.0)
      )
      forAll(table) { case (strategy, pHat, numRecords, lowerBound, upperBound) =>
        assert(strategy.calculateTargetConfidenceInterval(pHat, numRecords) == ConfidenceInterval(lowerBound, upperBound))
      }
    }
  }
}
Loading